Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 3ca5a30

Browse files
Initial support of DataFrame.GroupBy by single literal column (#590)
* Initial support of DataFrame.GroupBy by single literal column
* Fixing PEP and skipping one failed test
* Applying review comments for tests
1 parent bc86634 commit 3ca5a30

11 files changed

+924
-251
lines changed

sdc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import sdc.datatypes.hpat_pandas_series_rolling_functions
4545
import sdc.datatypes.hpat_pandas_seriesgroupby_functions
4646
import sdc.datatypes.hpat_pandas_stringmethods_functions
47+
import sdc.datatypes.hpat_pandas_groupby_functions
4748

4849
from ._version import get_versions
4950

sdc/datatypes/common_functions.py

Lines changed: 92 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@
4141
from numba import numpy_support
4242

4343
import sdc
44+
from sdc.hiframes.api import isna
4445
from sdc.hiframes.pd_series_type import SeriesType
45-
from sdc.str_arr_ext import (
46-
append_string_array_to, cp_str_list_to_array, num_total_chars,
47-
pre_alloc_string_array, str_arr_is_na, str_arr_set_na, string_array_type
48-
)
46+
from sdc.str_arr_type import string_array_type
47+
from sdc.str_arr_ext import (num_total_chars, append_string_array_to,
48+
str_arr_is_na, pre_alloc_string_array, str_arr_set_na, string_array_type,
49+
cp_str_list_to_array, create_str_arr_from_list, get_utf8_size)
4950
from sdc.utilities.utils import sdc_overload, sdc_register_jitable
5051
from sdc.utilities.sdc_typing_utils import (find_common_dtype_from_numpy_dtypes,
5152
TypeChecker)
@@ -483,18 +484,21 @@ def sdc_arrays_argsort(A, kind='quicksort'):
483484

484485
@sdc_overload(sdc_arrays_argsort, jit_options={'parallel': False})
485486
def sdc_arrays_argsort_overload(A, kind='quicksort'):
486-
"""Function overloading argsort for different 1D array types"""
487+
"""Function providing pandas argsort implementation for different 1D array types"""
487488

488489
# kind is not known at compile time, so get this function here and use in impl if needed
489490
quicksort_func = quicksort.make_jit_quicksort().run_quicksort
490491

492+
kind_is_default = isinstance(kind, str)
491493
if isinstance(A, types.Array):
492-
def _sdc_arrays_argsort_numeric_impl(A, kind='quicksort'):
493-
return numpy.argsort(A, kind=kind)
494-
return _sdc_arrays_argsort_numeric_impl
494+
def _sdc_arrays_argsort_array_impl(A, kind='quicksort'):
495+
_kind = 'quicksort' if kind_is_default == True else kind # noqa
496+
return numpy.argsort(A, kind=_kind)
497+
498+
return _sdc_arrays_argsort_array_impl
495499

496500
elif A == string_array_type:
497-
def _sdc_arrays_argsort_str_impl(A, kind='quicksort'):
501+
def _sdc_arrays_argsort_str_arr_impl(A, kind='quicksort'):
498502

499503
nan_mask = sdc.hiframes.api.get_nan_mask(A)
500504
idx = numpy.arange(len(A))
@@ -515,7 +519,10 @@ def _sdc_arrays_argsort_str_impl(A, kind='quicksort'):
515519
argsorted.extend(old_nan_positions)
516520
return numpy.asarray(argsorted, dtype=numpy.int32)
517521

518-
return _sdc_arrays_argsort_str_impl
522+
return _sdc_arrays_argsort_str_arr_impl
523+
524+
elif isinstance(A, types.List):
525+
return None
519526

520527
return None
521528

@@ -591,3 +598,78 @@ def _sdc_pandas_series_align_impl(series, other, size='max', finiteness=False):
591598
return aligned, aligned_other
592599

593600
return _sdc_pandas_series_align_impl
601+
602+
603+
def _sdc_asarray(data):
604+
pass
605+
606+
607+
@sdc_overload(_sdc_asarray, jit_options={'parallel': True})
def _sdc_asarray_overload(data):
    """Choose the ``_sdc_asarray`` implementation from the type of ``data``.

    Only list types are handled:

    * a list whose dtype is ``types.UnicodeType`` is converted with
      ``create_str_arr_from_list``;
    * a list of any other dtype is copied element by element into a new
      numpy array of that same dtype.

    For every non-list type ``None`` is returned, i.e. no implementation
    is provided.
    """

    # TODO: extend with other types
    if not isinstance(data, types.List):
        return None

    if not isinstance(data.dtype, types.UnicodeType):
        # Captured by the closure below so the output array dtype is a
        # compile-time constant for the implementation.
        elem_dtype = data.dtype

        def _sdc_asarray_list_impl(data):
            # TODO: check if elementwise copy is needed at all
            length = len(data)
            out = numpy.empty(length, dtype=elem_dtype)
            for pos in numba.prange(length):
                out[pos] = data[pos]
            return out

        return _sdc_asarray_list_impl

    def _sdc_asarray_str_list_impl(data):
        return create_str_arr_from_list(data)

    return _sdc_asarray_str_list_impl
634+
635+
636+
def _sdc_take(data, indexes):
637+
pass
638+
639+
640+
@sdc_overload(_sdc_take, jit_options={'parallel': True})
def _sdc_take_overload(data, indexes):
    """Choose the ``_sdc_take`` implementation from the type of ``data``.

    Two kinds of ``data`` are handled:

    * ``types.Array``: returns a new numpy array of the same dtype holding
      ``data[indexes[i]]`` for every position ``i`` of ``indexes``;
    * ``string_array_type``: returns a new string array with the selected
      strings, where positions taken from NA elements are marked NA again.

    For any other type ``None`` is returned, i.e. no implementation is
    provided.
    """

    if isinstance(data, types.Array):
        # Captured by the closure so the result dtype is a compile-time constant.
        elem_dtype = data.dtype

        def _sdc_take_array_impl(data, indexes):
            count = len(indexes)
            taken = numpy.empty(count, dtype=elem_dtype)
            for pos in numba.prange(count):
                taken[pos] = data[indexes[pos]]
            return taken

        return _sdc_take_array_impl

    if data == string_array_type:
        def _sdc_take_str_arr_impl(data, indexes):
            count = len(indexes)
            is_missing = numpy.zeros(count, dtype=numpy.bool_)

            # First pass: add up the UTF-8 size of every selected string and
            # remember which selected elements are NA in the source.
            total_bytes = 0
            for pos in numba.prange(count):
                src_pos = indexes[pos]
                total_bytes += get_utf8_size(data[src_pos])
                if isna(data, src_pos):
                    is_missing[pos] = True

            # Second pass: fill the pre-sized result and restore the NA marks.
            taken = pre_alloc_string_array(count, total_bytes)
            for pos in numpy.arange(count):
                taken[pos] = data[indexes[pos]]
                if is_missing[pos]:
                    str_arr_set_na(taken, pos)

            return taken

        return _sdc_take_str_arr_impl

    return None

sdc/datatypes/hpat_pandas_dataframe_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
from numba import types
4040
from numba.special import literally
41+
from numba.typed import List, Dict
42+
4143
from sdc.hiframes.pd_dataframe_ext import DataFrameType
4244
from sdc.hiframes.pd_series_type import SeriesType
4345
from sdc.utilities.sdc_typing_utils import (TypeChecker, check_index_is_numeric,
@@ -50,8 +52,10 @@
5052
from sdc.datatypes.hpat_pandas_dataframe_rolling_types import _hpat_pandas_df_rolling_init
5153
from sdc.datatypes.hpat_pandas_rolling_types import (
5254
gen_sdc_pandas_rolling_overload_body, sdc_pandas_rolling_docstring_tmpl)
55+
from sdc.datatypes.hpat_pandas_groupby_functions import init_dataframe_groupby
5356
from sdc.hiframes.pd_dataframe_ext import get_dataframe_data
5457
from sdc.utilities.utils import sdc_overload, sdc_overload_method, sdc_overload_attribute
58+
from sdc.hiframes.api import isna
5559

5660

5761
@sdc_overload_attribute(DataFrameType, 'index')
@@ -1337,3 +1341,32 @@ def pct_change_overload(df, periods=1, fill_method='pad', limit=None, freq=None)
13371341
ser_par = {'periods': 'periods', 'fill_method': 'fill_method', 'limit': 'limit', 'freq': 'freq'}
13381342

13391343
return sdc_pandas_dataframe_apply_columns(df, name, params, ser_par)
1344+
1345+
1346+
@sdc_overload_method(DataFrameType, 'groupby')
def sdc_pandas_dataframe_groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
                                 group_keys=True, squeeze=False, observed=False):
    """Overload of ``DataFrame.groupby`` for grouping by one literal column name.

    An implementation is provided only when ``by`` is typed as
    ``types.StringLiteral``; otherwise ``None`` is returned. The column
    position is resolved at compile time with ``self.columns.index``, which
    raises ``ValueError`` when the name is not among the columns.

    The implementation maps every non-NA value of the ``by`` column to the
    list of row positions at which it occurs and passes that mapping, with
    the column position and ``sort``, to ``init_dataframe_groupby``. The
    remaining parameters (``axis``, ``level``, ``as_index``, ``group_keys``,
    ``squeeze``, ``observed``) are accepted but not used here.
    """

    if not isinstance(by, types.StringLiteral):
        return None

    column_id = self.columns.index(by.literal_value)
    list_type = types.ListType(types.int64)
    by_type = self.data[column_id].dtype

    def sdc_pandas_dataframe_groupby_impl(self, by=None, axis=0, level=None, as_index=True, sort=True,
                                          group_keys=True, squeeze=False, observed=False):

        grouped = Dict.empty(by_type, list_type)
        by_column_data = get_dataframe_data(self, column_id)
        for i in numpy.arange(len(by_column_data)):
            # Rows whose key is NA do not belong to any group.
            if isna(by_column_data, i):
                continue
            value = by_column_data[i]
            if value in grouped:
                # The stored typed list is updated in place; no new list is
                # allocated and nothing has to be written back.
                grouped[value].append(i)
            else:
                # First occurrence of this key: create its list of row positions.
                group_list = List.empty_list(types.int64)
                group_list.append(i)
                grouped[value] = group_list

        return init_dataframe_groupby(self, column_id, grouped, sort)

    return sdc_pandas_dataframe_groupby_impl

sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from numba.types import (float64, Boolean, Integer, Number, Omitted,
3030
NoneType, StringLiteral, UnicodeType)
31-
from sdc.utilities.sdc_typing_utils import TypeChecker, params2list
31+
from sdc.utilities.sdc_typing_utils import TypeChecker, kwsparams2list
3232
from sdc.datatypes.hpat_pandas_dataframe_rolling_types import DataFrameRollingType
3333
from sdc.hiframes.pd_dataframe_ext import get_dataframe_data
3434
from sdc.hiframes.pd_dataframe_type import DataFrameType
@@ -95,7 +95,7 @@ def df_rolling_method_other_df_codegen(method_name, self, other, args=None, kws=
9595

9696
rolling_params = df_rolling_params_codegen()
9797
method_kws = {k: k for k in kwargs}
98-
impl_params = ['self'] + args + params2list(kwargs)
98+
impl_params = ['self'] + args + kwsparams2list(kwargs)
9999
impl_params_as_str = ', '.join(impl_params)
100100

101101
data_columns = {col: idx for idx, col in enumerate(self.data.columns)}
@@ -132,7 +132,7 @@ def df_rolling_method_other_df_codegen(method_name, self, other, args=None, kws=
132132
if col in common_columns:
133133
other_series = f'other_series_{col}'
134134
method_kws['other'] = other_series
135-
method_params = ', '.join(args + params2list(method_kws))
135+
method_params = ', '.join(args + kwsparams2list(method_kws))
136136
func_lines += [
137137
f' data_{col} = get_dataframe_data(self._data, {data_columns[col]})',
138138
f' other_data_{col} = get_dataframe_data(other, {other_columns[col]})',
@@ -189,7 +189,7 @@ def df_rolling_method_other_none_codegen(method_name, self, args=None, kws=None)
189189
args = args or []
190190
kwargs = kws or {}
191191

192-
impl_params = ['self'] + args + params2list(kwargs)
192+
impl_params = ['self'] + args + kwsparams2list(kwargs)
193193
impl_params_as_str = ', '.join(impl_params)
194194

195195
impl_name = f'_df_rolling_{_method_name}_other_none_impl'
@@ -223,7 +223,7 @@ def df_rolling_method_codegen(method_name, self, args=None, kws=None):
223223
args = args or []
224224
kwargs = kws or {}
225225

226-
impl_params = ['self'] + args + params2list(kwargs)
226+
impl_params = ['self'] + args + kwsparams2list(kwargs)
227227
impl_params_as_str = ', '.join(impl_params)
228228

229229
impl_name = f'_df_rolling_{method_name}_impl'

0 commit comments

Comments
 (0)