Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 4dd90c4

Browse files
authored
Overload df.getitem with bool array idx (#587)
1 parent 3ca5a30 commit 4dd90c4

File tree

2 files changed: +139 additions, -63 deletions

sdc/datatypes/hpat_pandas_dataframe_functions.py

Lines changed: 115 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1063,53 +1063,32 @@ def sdc_pandas_dataframe_drop_impl(df, _func_name, args, columns):
10631063
return sdc_pandas_dataframe_drop_impl(df, _func_name, args, columns)
10641064

10651065

1066-
def df_getitem_bool_series_idx_main_codelines(self, idx):
1067-
"""Generate main code lines for df.getitem"""
1068-
func_lines = [' self_length = len(get_dataframe_data(self, 0))',
1069-
' trimmed_idx_data = idx._data[:self_length]']
1070-
1071-
if isinstance(self.index, types.NoneType):
1072-
func_lines += [' self_index = numpy.arange(self_length)']
1073-
else:
1074-
func_lines += [' self_index = self._index']
1075-
1076-
results = []
1077-
for i, col in enumerate(self.columns):
1078-
res_data = f'res_data_{i}'
1079-
func_lines += [
1080-
f' data_{i} = get_dataframe_data(self, {i})',
1081-
f' series = pandas.Series(data_{i}, index=self_index, name="{col}")',
1082-
f' {res_data} = series[trimmed_idx_data]',
1083-
]
1084-
results.append((col, res_data))
1085-
1086-
data = ', '.join(f'"{col}": {data}' for col, data in results)
1087-
func_lines += [f' return pandas.DataFrame({{{data}}}, index=self_index[trimmed_idx_data])']
1066+
def df_length_codelines(self):
1067+
"""Generate code lines to get length of DF"""
1068+
if self.columns:
1069+
return [' length = len(get_dataframe_data(self, 0))']
10881070

1089-
return func_lines
1071+
return [' length = 0']
10901072

10911073

1092-
def df_index_codelines(self):
1074+
def df_index_codelines(self, with_length=False):
10931075
"""Generate code lines to get or create index of DF"""
1076+
func_lines = []
10941077
if isinstance(self.index, types.NoneType):
1095-
func_lines = [' length = len(get_dataframe_data(self, 0))',
1096-
' _index = numpy.arange(length)',
1097-
' res_index = _index']
1078+
if with_length:
1079+
func_lines += df_length_codelines(self)
1080+
1081+
func_lines += [' res_index = numpy.arange(length)']
10981082
else:
1099-
func_lines = [' res_index = self._index']
1083+
func_lines += [' res_index = self._index']
11001084

11011085
return func_lines
11021086

11031087

1104-
def df_getitem_key_error_codelines():
1105-
"""Generate code lines to raise KeyError"""
1106-
return [' raise KeyError("Column is not in the DataFrame")']
1107-
1108-
11091088
def df_getitem_slice_idx_main_codelines(self, idx):
11101089
"""Generate main code lines for df.getitem with idx of slice"""
11111090
results = []
1112-
func_lines = df_index_codelines(self)
1091+
func_lines = df_index_codelines(self, with_length=True)
11131092
for i, col in enumerate(self.columns):
11141093
res_data = f'res_data_{i}'
11151094
func_lines += [
@@ -1127,7 +1106,7 @@ def df_getitem_slice_idx_main_codelines(self, idx):
11271106
def df_getitem_tuple_idx_main_codelines(self, literal_idx):
11281107
"""Generate main code lines for df.getitem with idx of tuple"""
11291108
results = []
1130-
func_lines = df_index_codelines(self)
1109+
func_lines = df_index_codelines(self, with_length=True)
11311110
needed_cols = {col: i for i, col in enumerate(self.columns) if col in literal_idx}
11321111
for col, i in needed_cols.items():
11331112
res_data = f'res_data_{i}'
@@ -1143,33 +1122,53 @@ def df_getitem_tuple_idx_main_codelines(self, literal_idx):
11431122
return func_lines
11441123

11451124

1146-
def df_getitem_bool_series_codegen(self, idx):
1147-
"""
1148-
Example of generated implementation with provided index:
1149-
def _df_getitem_bool_series_idx_impl(self, idx):
1150-
self_length = len(get_dataframe_data(self, 0))
1151-
trimmed_idx_data = idx._data[:self_length]
1152-
self_index = self._index
1153-
data_0 = get_dataframe_data(self, 0)
1154-
series = pandas.Series(data_0, index=self_index, name="A")
1155-
res_data_0 = series[trimmed_idx_data]
1156-
data_1 = get_dataframe_data(self, 1)
1157-
series = pandas.Series(data_1, index=self_index, name="B")
1158-
res_data_1 = series[trimmed_idx_data]
1159-
return pandas.DataFrame({"A": res_data_0, "B": res_data_1}, index=self_index[trimmed_idx_data])
1160-
"""
1161-
func_lines = ['def _df_getitem_bool_series_idx_impl(self, idx):']
1162-
if self.columns:
1163-
func_lines += df_getitem_bool_series_idx_main_codelines(self, idx)
1164-
else:
1165-
# raise KeyError if input DF is empty
1166-
func_lines += df_getitem_key_error_codelines()
1125+
def df_getitem_bool_series_idx_main_codelines(self, idx):
1126+
"""Generate main code lines for df.getitem"""
1127+
func_lines = df_length_codelines(self)
1128+
func_lines += [' _idx_data = idx._data[:length]']
1129+
func_lines += df_index_codelines(self)
11671130

1168-
func_text = '\n'.join(func_lines)
1169-
global_vars = {'pandas': pandas, 'numpy': numpy,
1170-
'get_dataframe_data': get_dataframe_data}
1131+
results = []
1132+
for i, col in enumerate(self.columns):
1133+
res_data = f'res_data_{i}'
1134+
func_lines += [
1135+
f' data_{i} = get_dataframe_data(self, {i})',
1136+
f' series_{i} = pandas.Series(data_{i}, index=res_index, name="{col}")',
1137+
f' {res_data} = series_{i}[_idx_data]'
1138+
]
1139+
results.append((col, res_data))
1140+
1141+
data = ', '.join(f'"{col}": {data}' for col, data in results)
1142+
func_lines += [f' return pandas.DataFrame({{{data}}}, index=res_index[_idx_data])']
1143+
1144+
return func_lines
11711145

1172-
return func_text, global_vars
1146+
1147+
def df_getitem_bool_array_idx_main_codelines(self, idx):
1148+
"""Generate main code lines for df.getitem"""
1149+
func_lines = df_length_codelines(self)
1150+
func_lines += [' if length != len(idx):',
1151+
' raise ValueError("Item wrong length.")']
1152+
func_lines += df_index_codelines(self)
1153+
1154+
results = []
1155+
for i, col in enumerate(self.columns):
1156+
res_data = f'res_data_{i}'
1157+
func_lines += [
1158+
f' data_{i} = get_dataframe_data(self, {i})',
1159+
f' {res_data} = pandas.Series(data_{i}[idx], index=res_index[idx], name="{col}")'
1160+
]
1161+
results.append((col, res_data))
1162+
1163+
data = ', '.join(f'"{col}": {data}' for col, data in results)
1164+
func_lines += [f' return pandas.DataFrame({{{data}}}, index=res_index[idx])']
1165+
1166+
return func_lines
1167+
1168+
1169+
def df_getitem_key_error_codelines():
1170+
"""Generate code lines to raise KeyError"""
1171+
return [' raise KeyError("Column is not in the DataFrame")']
11731172

11741173

11751174
def df_getitem_slice_idx_codegen(self, idx):
@@ -1225,12 +1224,61 @@ def _df_getitem_tuple_idx_impl(self, idx)
12251224
return func_text, global_vars
12261225

12271226

1227+
def df_getitem_bool_series_idx_codegen(self, idx):
1228+
"""
1229+
Example of generated implementation with provided index:
1230+
def _df_getitem_bool_series_idx_impl(self, idx):
1231+
length = len(get_dataframe_data(self, 0))
1232+
_idx_data = idx._data[:length]
1233+
res_index = self._index
1234+
data_0 = get_dataframe_data(self, 0)
1235+
series_0 = pandas.Series(data_0, index=res_index, name="A")
1236+
res_data_0 = series_0[_idx_data]
1237+
data_1 = get_dataframe_data(self, 1)
1238+
series_1 = pandas.Series(data_1, index=res_index, name="B")
1239+
res_data_1 = series_1[_idx_data]
1240+
return pandas.DataFrame({"A": res_data_0, "B": res_data_1}, index=res_index[_idx_data])
1241+
"""
1242+
func_lines = ['def _df_getitem_bool_series_idx_impl(self, idx):']
1243+
func_lines += df_getitem_bool_series_idx_main_codelines(self, idx)
1244+
func_text = '\n'.join(func_lines)
1245+
global_vars = {'pandas': pandas, 'numpy': numpy,
1246+
'get_dataframe_data': get_dataframe_data}
1247+
1248+
return func_text, global_vars
1249+
1250+
1251+
def df_getitem_bool_array_idx_codegen(self, idx):
1252+
"""
1253+
Example of generated implementation with provided index:
1254+
def _df_getitem_bool_array_idx_impl(self, idx):
1255+
length = len(get_dataframe_data(self, 0))
1256+
if length != len(idx):
1257+
raise ValueError("Item wrong length.")
1258+
res_index = numpy.arange(length)
1259+
data_0 = get_dataframe_data(self, 0)
1260+
res_data_0 = pandas.Series(data_0[idx], index=res_index[idx], name="A")
1261+
data_1 = get_dataframe_data(self, 1)
1262+
res_data_1 = pandas.Series(data_1[idx], index=res_index[idx], name="B")
1263+
return pandas.DataFrame({"A": res_data_0, "B": res_data_1}, index=res_index[idx])
1264+
"""
1265+
func_lines = ['def _df_getitem_bool_array_idx_impl(self, idx):']
1266+
func_lines += df_getitem_bool_array_idx_main_codelines(self, idx)
1267+
func_text = '\n'.join(func_lines)
1268+
global_vars = {'pandas': pandas, 'numpy': numpy,
1269+
'get_dataframe_data': get_dataframe_data}
1270+
1271+
return func_text, global_vars
1272+
1273+
12281274
gen_df_getitem_slice_idx_impl = gen_df_impl_generator(
12291275
df_getitem_slice_idx_codegen, '_df_getitem_slice_idx_impl')
12301276
gen_df_getitem_tuple_idx_impl = gen_df_impl_generator(
12311277
df_getitem_tuple_idx_codegen, '_df_getitem_tuple_idx_impl')
12321278
gen_df_getitem_bool_series_idx_impl = gen_df_impl_generator(
1233-
df_getitem_bool_series_codegen, '_df_getitem_bool_series_idx_impl')
1279+
df_getitem_bool_series_idx_codegen, '_df_getitem_bool_series_idx_impl')
1280+
gen_df_getitem_bool_array_idx_impl = gen_df_impl_generator(
1281+
df_getitem_bool_array_idx_codegen, '_df_getitem_bool_array_idx_impl')
12341282

12351283

12361284
@sdc_overload(operator.getitem)
@@ -1289,7 +1337,12 @@ def _df_getitem_unicode_idx_impl(self, idx):
12891337

12901338
return gen_df_getitem_bool_series_idx_impl(self, idx)
12911339

1292-
ty_checker.raise_exc(idx, 'str', 'idx')
1340+
if isinstance(idx, types.Array) and isinstance(idx.dtype, types.Boolean):
1341+
return gen_df_getitem_bool_array_idx_impl(self, idx)
1342+
1343+
ty_checker = TypeChecker('Operator getitem().')
1344+
expected_types = 'str, tuple(str), slice, series(bool), array(bool)'
1345+
ty_checker.raise_exc(idx, expected_types, 'idx')
12931346

12941347

12951348
@sdc_overload_method(DataFrameType, 'pct_change')

sdc/tests/test_dataframe.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1277,6 +1277,15 @@ def test_impl(df, series):
12771277
sdc_func = self.jit(test_impl)
12781278
pd.testing.assert_frame_equal(sdc_func(df, s), test_impl(df, s))
12791279

1280+
def _test_df_getitem_bool_array_even_idx(self, df):
1281+
def test_impl(df, arr):
1282+
return df[arr]
1283+
1284+
arr = np.array([i % 2 for i in range(len(df))], dtype=np.bool_)
1285+
1286+
sdc_func = self.jit(test_impl)
1287+
pd.testing.assert_frame_equal(sdc_func(df, arr), test_impl(df, arr))
1288+
12801289
@skip_sdc_jit('DF.getitem unsupported exceptions')
12811290
def test_df_getitem_str_literal_idx_exception_key_error(self):
12821291
def test_impl(df):
@@ -1301,7 +1310,7 @@ def test_impl(df, idx):
13011310
with self.assertRaises(KeyError):
13021311
sdc_func(df, 'ABC')
13031312

1304-
@skip_sdc_jit('DF.getitem unsupported Series name')
1313+
@skip_sdc_jit('DF.getitem unsupported exceptions')
13051314
def test_df_getitem_tuple_idx_exception_key_error(self):
13061315
sdc_func = self.jit(lambda df: df[('A', 'Z')])
13071316

@@ -1310,6 +1319,18 @@ def test_df_getitem_tuple_idx_exception_key_error(self):
13101319
with self.assertRaises(KeyError):
13111320
sdc_func(df)
13121321

1322+
@skip_sdc_jit('DF.getitem unsupported exceptions')
1323+
def test_df_getitem_bool_array_idx_exception_value_error(self):
1324+
sdc_func = self.jit(lambda df, arr: df[arr])
1325+
1326+
for df in [gen_df(test_global_input_data_float64), pd.DataFrame()]:
1327+
arr = np.array([i % 2 for i in range(len(df) + 1)], dtype=np.bool_)
1328+
with self.subTest(df=df, arr=arr):
1329+
with self.assertRaises(ValueError) as raises:
1330+
sdc_func(df, arr)
1331+
self.assertIn('Item wrong length', str(raises.exception))
1332+
1333+
13131334
@skip_sdc_jit('DF.getitem unsupported Series name')
13141335
def test_df_getitem_idx(self):
13151336
dfs = [gen_df(test_global_input_data_float64),
@@ -1330,6 +1351,7 @@ def test_df_getitem_idx_no_index(self):
13301351
for df in dfs:
13311352
with self.subTest(df=df):
13321353
self._test_df_getitem_bool_series_even_idx(df)
1354+
self._test_df_getitem_bool_array_even_idx(df)
13331355

13341356
@skip_sdc_jit('DF.getitem unsupported Series name')
13351357
def test_df_getitem_idx_multiple_types(self):
@@ -1345,6 +1367,7 @@ def test_df_getitem_idx_multiple_types(self):
13451367
self._test_df_getitem_unbox_slice_idx(df, 1, 3)
13461368
self._test_df_getitem_tuple_idx(df)
13471369
self._test_df_getitem_bool_series_even_idx(df)
1370+
self._test_df_getitem_bool_array_even_idx(df)
13481371

13491372
@unittest.skip('DF.getitem df[bool_series] unsupported index')
13501373
def test_df_getitem_bool_series_even_idx_with_index(self):

0 commit comments

Comments
 (0)