Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit a98cbeb

Browse files
Adding initial support for DataFrameGroupBy.getitem (#622)
1 parent 83681fc commit a98cbeb

File tree

3 files changed

+166
-42
lines changed

3 files changed

+166
-42
lines changed

sdc/datatypes/hpat_pandas_groupby_functions.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import pandas
3232
import numba
3333
import numpy
34+
import operator
3435
import sdc
3536

3637
from numba import types
@@ -39,41 +40,94 @@
3940
from numba.targets.registry import cpu_target
4041
from numba.typed import List, Dict
4142
from numba.typing import signature
43+
from numba.special import literally
4244

4345
from sdc.datatypes.common_functions import sdc_arrays_argsort, _sdc_asarray, _sdc_take
4446
from sdc.datatypes.hpat_pandas_groupby_types import DataFrameGroupByType
4547
from sdc.utilities.sdc_typing_utils import TypeChecker, kwsparams2list, sigparams2list
46-
from sdc.utilities.utils import sdc_overload_method, sdc_overload_attribute
48+
from sdc.utilities.utils import sdc_overload, sdc_overload_method, sdc_overload_attribute
4749
from sdc.hiframes.pd_dataframe_ext import get_dataframe_data
4850
from sdc.hiframes.pd_series_type import SeriesType
51+
from sdc.str_ext import string_type
4952

5053

5154
@intrinsic
def init_dataframe_groupby(typingctx, parent, column_id, data, sort, target_columns=None):
    """
    Intrinsic that builds a DataFrameGroupByType value from its components.

    At typing time the names of the columns the groupby operates on are resolved:
    when target_columns is not given (or typed as none) every column of parent
    except the one at index column_id.literal_value is selected, otherwise the
    literal values of the types in target_columns are used.

    Returns a pair of the call signature and the codegen function that fills
    the native structure.
    """

    if target_columns is None:
        target_columns = types.none

    # True when the caller did not narrow down the set of columns
    target_not_specified = isinstance(target_columns, types.NoneType)
    if target_not_specified:
        selected_col_names = tuple(
            name for pos, name in enumerate(parent.columns)
            if pos != column_id.literal_value)
    else:
        selected_col_names = tuple(col.literal_value for col in target_columns)

    n_target_cols = len(selected_col_names)
    columns_tuple_type = types.UniTuple(string_type, n_target_cols)

    # NOTE: the 'signature' parameter below shadows the imported signature() helper
    def codegen(context, builder, signature, args):
        # the last argument is not stored as is, selected names are emitted as constants
        parent_val, column_id_val, data_val, sort_val, _target_columns_val = args

        # allocate the struct of the return type and fill its members
        struct_proxy_cls = cgutils.create_struct_proxy(signature.return_type)
        groupby_obj = struct_proxy_cls(context, builder)
        groupby_obj.parent = parent_val
        groupby_obj.col_id = column_id_val
        groupby_obj.data = data_val
        groupby_obj.sort = sort_val
        groupby_obj.target_default = context.get_constant(types.bool_, target_not_specified)

        # one constant string per selected column, packed into a homogeneous tuple
        column_strs = [
            numba.unicode.make_string_from_constant(context, builder, string_type, col_name)
            for col_name in selected_col_names
        ]
        groupby_obj.target_columns = context.make_tuple(builder, columns_tuple_type, column_strs)

        # values stored in the struct need their refcount increased
        if context.enable_nrt:
            stored_values = (parent_val, column_id_val, data_val)
            for arg_type, arg_val in zip(signature.args[:3], stored_values):
                context.nrt.incref(builder, arg_type, arg_val)
            for col_str in column_strs:
                context.nrt.incref(builder, string_type, col_str)

        return groupby_obj._getvalue()

    ret_typ = DataFrameGroupByType(parent, column_id, selected_col_names)
    sig = signature(ret_typ, parent, column_id, data, sort, target_columns)
    return sig, codegen
7597

7698

99+
@sdc_overload(operator.getitem)
def sdc_pandas_dataframe_getitem(self, idx):
    """
    Overload of operator.getitem for DataFrameGroupByType objects.

    Supported idx types: a string literal, a tuple of string literals, or a
    non-literal unicode string (in which case literal dispatch is requested).
    For any other combination of argument types no implementation is provided.
    """

    if not isinstance(self, DataFrameGroupByType):
        return None

    idx_is_literal_str = isinstance(idx, types.StringLiteral)
    idx_is_literal_str_tuple = (
        not idx_is_literal_str
        and isinstance(idx, types.Tuple)
        and all(isinstance(elem, types.StringLiteral) for elem in idx)
    )

    if not (idx_is_literal_str or idx_is_literal_str_tuple):
        if isinstance(idx, types.UnicodeType):
            def sdc_pandas_dataframe_getitem_idx_unicode_str_impl(self, idx):
                # requesting a literal value raises, after that compilation
                # is retried and goes through the common implementation
                return literally(idx)
            return sdc_pandas_dataframe_getitem_idx_unicode_str_impl

        return None

    # values below are captured by the implementation as compile-time constants
    col_id_literal = self.col_id.literal_value
    if idx_is_literal_str:
        idx_literal = idx.literal_value
    else:
        idx_literal = None

    def sdc_pandas_dataframe_getitem_common_impl(self, idx):

        _idx = (idx_literal, ) if idx_is_literal_str == True else idx  # noqa
        # a second getitem on the same object raises IndexError, matching pandas
        if not self._target_default:
            raise IndexError("DataFrame.GroupBy.getitem: Columns already selected")
        return init_dataframe_groupby(self._parent, col_id_literal, self._data, self._sort, _idx)

    return sdc_pandas_dataframe_getitem_common_impl
129+
130+
77131
def _sdc_pandas_groupby_generic_func_codegen(func_name, columns, func_params, defaults, impl_params):
78132

79133
all_params_as_str = ', '.join(sigparams2list(func_params, defaults))
@@ -155,7 +209,8 @@ def sdc_pandas_groupby_apply_func(self, func_name, func_args, defaults=None, imp
155209
df_column_types = self.parent.data
156210
df_column_names = self.parent.columns
157211
by_column_id = self.col_id.literal_value
158-
subject_columns = [(name, i) for i, name in enumerate(df_column_names) if i != by_column_id]
212+
selected_cols_set = set(self.target_columns)
213+
subject_columns = [(name, i) for i, name in enumerate(df_column_names) if name in selected_cols_set]
159214

160215
# resolve types of result dataframe columns
161216
res_arrays_dtypes = tuple(

sdc/datatypes/hpat_pandas_groupby_types.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@
2626

2727

2828
import numba
29-
from numba import types, cgutils
29+
from numba import types
3030
from numba.extending import (models, register_model, make_attribute_wrapper)
31-
from numba.typed import Dict, List
3231
from sdc.str_ext import string_type
3332

3433

@@ -37,15 +36,16 @@ class DataFrameGroupByType(types.Type):
3736
Type definition for DataFrameGroupBy functions handling.
3837
"""
3938

40-
def __init__(self, parent, col_id, target_columns):
    """
    Create the type of a groupby object built over a dataframe.

    Parameters
    ----------
    parent
        stored as the ``parent`` attribute (the type being grouped)
    col_id
        stored as the ``col_id`` attribute (identifies the column to group by)
    target_columns
        stored as the ``target_columns`` attribute (names of selected columns)
    """
    self.parent = parent
    self.col_id = col_id
    self.target_columns = target_columns
    # all three components take part in the type name: types that differ only
    # in the selected columns are different types (see key) and must not share a name
    super(DataFrameGroupByType, self).__init__(
        name="DataFrameGroupByType({}, {}, {})".format(parent, col_id, target_columns))
4545

4646
@property
def key(self):
    """Tuple of the parent, col_id and target_columns attributes of this type."""
    identity = (self.parent, self.col_id, self.target_columns)
    return identity
4949

5050

5151
@register_model(DataFrameGroupByType)
@@ -56,11 +56,15 @@ def __init__(self, dmm, fe_type):
5656
by_series_dtype,
5757
types.containers.ListType(types.int64)
5858
)
59+
60+
n_target_cols = len(fe_type.target_columns)
5961
members = [
6062
('parent', fe_type.parent),
6163
('col_id', types.int64),
6264
('data', ty_data),
63-
('sort', types.bool_)
65+
('sort', types.bool_),
66+
('target_default', types.bool_),
67+
('target_columns', types.UniTuple(string_type, n_target_cols))
6468
]
6569
super(DataFrameGroupByModel, self).__init__(dmm, fe_type, members)
6670

@@ -69,3 +73,5 @@ def __init__(self, dmm, fe_type):
6973
make_attribute_wrapper(DataFrameGroupByType, 'col_id', '_col_id')
7074
make_attribute_wrapper(DataFrameGroupByType, 'data', '_data')
7175
make_attribute_wrapper(DataFrameGroupByType, 'sort', '_sort')
76+
make_attribute_wrapper(DataFrameGroupByType, 'target_default', '_target_default')
77+
make_attribute_wrapper(DataFrameGroupByType, 'target_columns', '_target_columns')

sdc/tests/test_groupby.py

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -258,45 +258,53 @@ def test_impl(df):
258258
# np.testing.assert_array_equal(hpat_func(df), test_impl(df))
259259
self.assertEqual(set(hpat_func(df)), set(test_impl(df)))
260260

261-
@skip_numba_jit("BUG: SDC impl of Series.sum returns float64 on as series of ints")
def test_agg_seq_sum(self):
    # checks groupby('A')['B'].sum() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].sum()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
270272

271-
@skip_sdc_jit("Old-style implementation returns ndarray, not a Series")
def test_agg_seq_count(self):
    # checks groupby('A')['B'].count() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].count()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
280284

281-
@skip_sdc_jit("Old-style implementation returns ndarray, not a Series")
def test_agg_seq_mean(self):
    # checks groupby('A')['B'].mean() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].mean()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
290296

291-
@skip_sdc_jit("Old-style implementation returns ndarray, not a Series")
def test_agg_seq_min(self):
    # checks groupby('A')['B'].min() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].min()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
300308

301309
@skip_numba_jit
302310
def test_agg_seq_min_date(self):
@@ -308,15 +316,17 @@ def test_impl(df):
308316
df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': pd.date_range('2019-1-3', '2019-1-9')})
309317
self.assertEqual(set(hpat_func(df)), set(test_impl(df)))
310318

311-
@skip_sdc_jit("Old-style implementation returns ndarray, not a Series")
def test_agg_seq_max(self):
    # checks groupby('A')['B'].max() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].max()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
320330

321331
@skip_numba_jit
322332
def test_agg_seq_all_col(self):
@@ -338,37 +348,43 @@ def test_impl(df):
338348
df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
339349
self.assertEqual(set(hpat_func(df)), set(test_impl(df)))
340350

341-
@skip_sdc_jit("Old-style implementation returns ndarray, not a Series")
def test_agg_seq_prod(self):
    # checks groupby('A')['B'].prod() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].prod()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
350362

351363
@skip_sdc_jit
@skip_numba_jit
def test_agg_seq_var(self):
    # checks groupby('A')['B'].var() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].var()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
361375

362376
@skip_sdc_jit
@skip_numba_jit
def test_agg_seq_std(self):
    # checks groupby('A')['B'].std() of the jitted function against pandas
    def test_impl(df):
        return df.groupby('A')['B'].std()

    df = pd.DataFrame({'A': [2, 1, 1, 1, 2, 2, 1], 'B': [-8, 2, 3, 1, 5, 6, 7]})
    sdc_func = self.jit(test_impl)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
372388

373389
@skip_numba_jit
374390
def test_agg_seq_multiselect(self):
@@ -661,6 +677,53 @@ def test_impl(df):
661677
hpat_func = self.jit(test_impl)
662678
pd.testing.assert_frame_equal(hpat_func(df), test_impl(df))
663679

680+
def test_dataframe_groupby_getitem_literal_tuple(self):
    # selecting several columns with a tuple of string literals
    def test_impl(df):
        return df.groupby('A')['B', 'C'].count()

    sdc_func = self.jit(test_impl)
    df = pd.DataFrame(_default_df_numeric_data)

    actual = sdc_func(df)
    expected = test_impl(df)
    # TODO: implement index classes, as current indexes do not have names
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
690+
691+
def test_dataframe_groupby_getitem_literal_str(self):
    # selecting a single column with a string literal
    def test_impl(df):
        return df.groupby('C')['B'].count()

    sdc_func = self.jit(test_impl)
    df = pd.DataFrame(_default_df_numeric_data)

    actual = sdc_func(df)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df))
    # TODO: implement index classes, as current indexes do not have names
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
702+
703+
def test_dataframe_groupby_getitem_unicode_str(self):
    # selecting a single column by a name passed as a function argument
    def test_impl(df, col_name):
        return df.groupby('A')[col_name].count()

    sdc_func = self.jit(test_impl)
    df = pd.DataFrame(_default_df_numeric_data)
    col_name = 'C'

    actual = sdc_func(df, col_name)
    # pandas produces a Series for a single selected column, so the reference
    # is wrapped into a DataFrame before frames are compared
    expected = pd.DataFrame(test_impl(df, col_name))
    # TODO: implement index classes, as current indexes do not have names
    pd.testing.assert_frame_equal(actual, expected, check_names=False)
715+
716+
def test_dataframe_groupby_getitem_repeated(self):
    # applying getitem twice must fail with the same exception type as in pandas
    def test_impl(df):
        return df.groupby('A')['B', 'C']['D']

    sdc_func = self.jit(test_impl)
    df = pd.DataFrame(_default_df_numeric_data)

    # first find out which exception pandas itself raises
    with self.assertRaises(Exception) as raised:
        test_impl(df)
    expected_exc_type = type(raised.exception)

    with self.assertRaises(expected_exc_type):
        sdc_func(df)
664727

665728
if __name__ == "__main__":
666729
unittest.main()

0 commit comments

Comments
 (0)