From b2425aeb6b665007a15d85342ea9615a715ad180 Mon Sep 17 00:00:00 2001 From: Alistair Johnson Date: Fri, 20 Apr 2018 09:36:23 -0400 Subject: [PATCH 1/2] making column checking robust to passing pandas index --- tableone.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tableone.py b/tableone.py index ac3f005..bbc5700 100644 --- a/tableone.py +++ b/tableone.py @@ -84,8 +84,10 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, nonnormal = [nonnormal] # if columns not specified, use all columns - if not columns: + if type(columns) == type(None): columns = data.columns.get_values() + elif 'pandas.core.indexes' in str(type(columns)): + columns = columns.get_values() # check that the columns exist in the dataframe if not set(columns).issubset(data.columns): @@ -98,7 +100,11 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, raise InputError('Input contains duplicate columns: {}'.format(dups)) # if categorical not specified, try to identify categorical - if not categorical and type(categorical) != list: + if type(columns) == type(None): + categorical = self._detect_categorical_columns(data[columns]) + elif 'pandas.core.indexes' in str(type(categorical)): + categorical = categorical.get_values() + elif type(categorical) != list: categorical = self._detect_categorical_columns(data[columns]) if pval and not groupby: @@ -267,7 +273,7 @@ def _normaltest(self,x): Compute test for normal distribution. Null hypothesis: x comes from a normal distribution - p < alpha suggests the null hypothesis can be rejected. + p < alpha suggests the null hypothesis can be rejected. """ stat,p = stats.normaltest(x.values, nan_policy='omit') return p From ac7345b51da2707ac32bafcf4229f0bf947f77e9 Mon Sep 17 00:00:00 2001 From: Alistair Johnson Date: Fri, 20 Apr 2018 10:29:23 -0400 Subject: [PATCH 2/2] explicitly convert categorical/columns to list if not recognized --- tableone.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tableone.py b/tableone.py index bbc5700..61155bb 100644 --- a/tableone.py +++ b/tableone.py @@ -88,6 +88,8 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, columns = data.columns.get_values() elif 'pandas.core.indexes' in str(type(columns)): columns = columns.get_values() + else: + columns = list(columns) # check that the columns exist in the dataframe if not set(columns).issubset(data.columns): @@ -104,13 +106,13 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, categorical = self._detect_categorical_columns(data[columns]) elif 'pandas.core.indexes' in str(type(categorical)): categorical = categorical.get_values() - elif type(categorical) != list: - categorical = self._detect_categorical_columns(data[columns]) + else: + categorical = list(columns) if pval and not groupby: raise InputError("If pval=True then the groupby must be specified.") - self._columns = list(columns) + self._columns = columns self._isnull = isnull self._continuous = [c for c in columns if c not in categorical + [groupby]] self._categorical = categorical