Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit cfbebc6

Browse files
Fix for DataFrame str column created as List of Optional values (#582)
* Fix for DataFrame str column created as List of Optional values * Applying review comments * Fixing more remarks
1 parent 0a1b051 commit cfbebc6

File tree

5 files changed

+93
-49
lines changed

5 files changed

+93
-49
lines changed

sdc/hiframes/api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,9 @@ def generic(self, args, kws):
11851185
and (isinstance(column.dtype, types.Number)
11861186
or column.dtype == types.boolean)):
11871187
ret_typ = types.Array(column.dtype, 1, 'C')
1188-
if isinstance(column, types.List) and column.dtype == string_type:
1188+
if (isinstance(column, types.List)
1189+
and (column.dtype == string_type
1190+
or isinstance(column.dtype, types.Optional) and column.dtype.type == string_type)):
11891191
ret_typ = string_array_type
11901192
if isinstance(column, DatetimeIndexType):
11911193
ret_typ = sdc.hiframes.pd_index_ext._dt_index_data_typ
@@ -1214,7 +1216,10 @@ def fix_df_array_list_impl(column): # pragma: no cover
12141216
return fix_df_array_list_impl
12151217

12161218
# convert list of strings to string array
1217-
if isinstance(column, types.List) and column.dtype == string_type:
1219+
if (isinstance(column, types.List)
1220+
and (column.dtype == string_type
1221+
or isinstance(column.dtype, types.Optional) and column.dtype.type == string_type)):
1222+
12181223
def fix_df_array_str_impl(column): # pragma: no cover
12191224
return sdc.str_arr_ext.StringArray(column)
12201225
return fix_df_array_str_impl

sdc/str_arr_ext.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -670,33 +670,43 @@ def construct_string_array(context, builder):
670670
@lower_builtin(StringArray, types.UniTuple)
671671
@lower_builtin(StringArray, types.Tuple)
672672
def impl_string_array_single(context, builder, sig, args):
673-
if isinstance(args[0], types.UniTuple):
674-
assert args[0].dtype == string_type
675673

676-
if isinstance(args[0], types.Tuple):
677-
for i in args[0]:
674+
arg = args[0]
675+
if isinstance(arg, (types.UniTuple, types.List)):
676+
assert (arg.dtype == string_type
677+
or (isinstance(arg.dtype, types.Optional) and arg.dtype.type == string_type))
678+
679+
# FIXME: doesn't work for Tuple with None values
680+
if isinstance(arg, types.Tuple):
681+
for i in arg:
678682
assert i.dtype == string_type or i.dtype == types.StringLiteral
679683

680684
if not sig.args: # return empty string array if no args
681685
res = context.compile_internal(
682686
builder, lambda: pre_alloc_string_array(0, 0), sig, args)
683687
return res
684688

685-
def str_arr_from_list(in_list):
689+
def str_arr_from_sequence(in_list):
686690
n_strs = len(in_list)
687691
total_chars = 0
688692
# TODO: use vector to avoid two passes?
689693
# get total number of chars
690-
for s in in_list:
691-
total_chars += get_utf8_size(s)
694+
nan_mask = np.zeros(n_strs, dtype=np.bool_)
695+
for i in numba.prange(n_strs):
696+
s = in_list[i]
697+
if s is None:
698+
nan_mask[i] = True
699+
else:
700+
total_chars += get_utf8_size(s)
692701

693702
A = pre_alloc_string_array(n_strs, total_chars)
694-
for i in range(n_strs):
695-
A[i] = in_list[i]
703+
for i in np.arange(n_strs):
704+
A[i] = '' if nan_mask[i] else in_list[i]
705+
str_arr_set_na_by_mask(A, nan_mask)
696706

697707
return A
698708

699-
res = context.compile_internal(builder, str_arr_from_list, sig, args)
709+
res = context.compile_internal(builder, str_arr_from_sequence, sig, args)
700710
return res
701711

702712
# @lower_builtin(StringArray)

sdc/tests/test_dataframe.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import unittest
3434
from itertools import permutations, product
3535
from numba.config import IS_32BITS
36+
from numba.special import literal_unroll
3637

3738
import sdc
3839
from sdc.tests.gen_test_data import ParquetGenerator
@@ -1726,6 +1727,47 @@ def test_impl():
17261727
self.assertTrue(isinstance(two, np.ndarray))
17271728
self.assertTrue(isinstance(three, np.ndarray))
17281729

1730+
def test_df_iterate_over_columns1(self):
1731+
""" Verifies iteration over df columns using literal tuple of column indices. """
1732+
from sdc.hiframes.pd_dataframe_ext import get_dataframe_data
1733+
from sdc.hiframes.api import get_nan_mask
1734+
1735+
@self.jit
1736+
def jitted_func():
1737+
df = pd.DataFrame({
1738+
'A': ['a', 'b', None, 'a', '', None, 'b'],
1739+
'B': ['a', 'b', 'd', 'a', '', 'c', 'b'],
1740+
'C': [np.nan, 1, 2, 1, np.nan, 2, 1],
1741+
'D': [1, 2, 9, 5, 2, 1, 0]
1742+
})
1743+
1744+
# tuple of literals has to be created in a jitted function, otherwise
1745+
# col_id won't be literal and unboxing in get_dataframe_data won't compile
1746+
column_ids = (0, 1, 2, 3)
1747+
res_nan_mask = np.zeros(len(df), dtype=np.bool_)
1748+
for col_id in literal_unroll(column_ids):
1749+
res_nan_mask += get_nan_mask(get_dataframe_data(df, col_id))
1750+
return res_nan_mask
1751+
1752+
# expected is a boolean mask of df rows that have None values
1753+
expected = np.asarray([True, False, True, False, True, True, False])
1754+
result = jitted_func()
1755+
np.testing.assert_array_equal(result, expected)
1756+
1757+
def test_df_create_str_with_none(self):
1758+
""" Verifies creation of a dataframe with a string column from a list of Optional values. """
1759+
def test_impl():
1760+
df = pd.DataFrame({
1761+
'A': ['a', 'b', None, 'a', '', None, 'b'],
1762+
'B': ['a', 'b', 'd', 'a', '', 'c', 'b'],
1763+
'C': [np.nan, 1, 2, 1, np.nan, 2, 1]
1764+
})
1765+
1766+
return df['A'].isna()
1767+
hpat_func = self.jit(test_impl)
1768+
1769+
pd.testing.assert_series_equal(hpat_func(), test_impl())
1770+
17291771

17301772
if __name__ == "__main__":
17311773
unittest.main()

sdc/tests/test_series.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6291,14 +6291,13 @@ def test_impl(A, i, value):
62916291
for series_data in all_data:
62926292
for series_index in indexes:
62936293
S = pd.Series(series_data, series_index, dtype=dtype)
6294-
for idx in idxs:
6295-
for value in values:
6296-
with self.subTest(series=S, idx=idx, value=value):
6297-
S1 = S.copy(deep=True)
6298-
S2 = S.copy(deep=True)
6299-
hpat_func(S1, idx, value)
6300-
test_impl(S2, idx, value)
6301-
pd.testing.assert_series_equal(S1, S2)
6294+
for idx, value in product(idxs, values):
6295+
with self.subTest(series=S, idx=idx, value=value):
6296+
S1 = S.copy(deep=True)
6297+
S2 = S.copy(deep=True)
6298+
hpat_func(S1, idx, value)
6299+
test_impl(S2, idx, value)
6300+
pd.testing.assert_series_equal(S1, S2)
63026301

63036302
@skip_sdc_jit('Not implemented in old-pipeline')
63046303
@skip_numba_jit('Requires StringArray support of operator.eq')
@@ -6365,11 +6364,10 @@ def test_series_setitem_idx_str_series(self):
63656364
integer Series with index of matching dtype and scalar and non scalar assigned values """
63666365

63676366
n, k = 11, 4
6368-
np.random.seed(0)
63696367
series_data = np.arange(n)
63706368
series_index = gen_strlist(n, 2, 'abcd123 ')
63716369

6372-
idx = create_series_from_values(k, series_index)
6370+
idx = create_series_from_values(k, series_index, seed=0)
63736371
assigned_values = -10 + np.arange(k) * (-1)
63746372
values_to_test = [-100,
63756373
np.array(assigned_values),
@@ -6382,11 +6380,10 @@ def test_series_setitem_idx_float_series(self):
63826380
integer Series with index of matching dtype and scalar and non scalar assigned values """
63836381

63846382
n, k = 11, 4
6385-
np.random.seed(0)
63866383
series_data = np.arange(n)
63876384
series_index = np.arange(n, dtype=np.float)
63886385

6389-
idx = create_series_from_values(k, series_index)
6386+
idx = create_series_from_values(k, series_index, seed=0)
63906387
assigned_values = -10 + np.arange(k) * (-1)
63916388
values_to_test = [
63926389
-100,
@@ -6404,11 +6401,10 @@ def test_impl(A, i, value):
64046401
hpat_func = self.jit(test_impl)
64056402

64066403
n, k = 11, 4
6407-
np.random.seed(0)
64086404
series_data = np.arange(n)
64096405
series_index = np.arange(n)
64106406

6411-
idx = create_series_from_values(k, series_index)
6407+
idx = create_series_from_values(k, series_index, seed=0)
64126408
assigned_values = -10 + np.arange(k) * (-1)
64136409
values_to_test = [-100,
64146410
np.array(assigned_values),
@@ -6421,11 +6417,10 @@ def test_series_setitem_idx_int_series2(self):
64216417
integer Series with index of non-matching dtype and scalar and non scalar assigned values """
64226418

64236419
n, k = 11, 4
6424-
np.random.seed(0)
64256420
series_data = np.arange(n)
64266421
series_index = gen_strlist(n, 2, 'abcd123 ')
64276422

6428-
idx = create_series_from_values(k, np.arange(n))
6423+
idx = create_series_from_values(k, np.arange(n), seed=0)
64296424
assigned_values = -10 + np.arange(k) * (-1)
64306425
values_to_test = [-100,
64316426
np.array(assigned_values),
@@ -6476,12 +6471,11 @@ def test_series_setitem_idx_int_array1(self):
64766471
integer Series with integer index and scalar and non scalar assigned values """
64776472

64786473
n, k = 11, 4
6479-
np.random.seed(0)
6480-
64816474
series_data = np.arange(n)
64826475
series_index = np.arange(n)
64836476

6484-
idx = take_k_elements(k, series_index)
6477+
np.random.seed(0)
6478+
idx = take_k_elements(k, series_index, seed=0)
64856479
assigned_values = -10 + np.arange(k) * (-1)
64866480
values_to_test = [
64876481
-100,
@@ -6496,11 +6490,10 @@ def test_series_setitem_idx_int_array2(self):
64966490
integer Series with string index and scalar and non scalar assigned values """
64976491

64986492
n, k = 11, 4
6499-
np.random.seed(0)
65006493
series_data = np.arange(n)
65016494
series_index = gen_strlist(n, 2, 'abcd123 ')
65026495

6503-
idx = take_k_elements(k, np.arange(n))
6496+
idx = take_k_elements(k, np.arange(n), seed=0)
65046497
assigned_values = -10 + np.arange(k) * (-1)
65056498
values_to_test = [
65066499
-100,

sdc/tests/test_utils.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,22 @@ def skip_inline(msg_or_func):
212212
return wrapper(func) if func else wrapper
213213

214214

215-
def take_k_elements(k, data):
216-
random_idx = np.arange(len(data))
217-
np.random.shuffle(random_idx)
218-
return np.asarray(data).take(random_idx[:k])
215+
def take_k_elements(k, data, repeat=False, seed=None):
216+
if seed is not None:
217+
np.random.seed(seed)
218+
return np.random.choice(np.asarray(data), k, replace=repeat)
219219

220220

221-
def create_series_from_values(size, data_values, index_values=None, name=None, unique=True):
221+
def create_series_from_values(size, data_values, index_values=None, name=None, unique=True, seed=None):
222+
if seed is not None:
223+
np.random.seed(seed)
222224

223225
min_size = min(size, len(data_values))
224226
if index_values:
225227
min_size = min(min_size, len(index_values))
228+
repeat = False if unique and min_size == size else True
226229

227-
if unique and min_size == size:
228-
series_data = take_k_elements(size, data_values)
229-
series_index = take_k_elements(size, index_values) if index_values else None
230-
else:
231-
data_values_pos = np.random.randint(0, len(data_values), size)
232-
series_data = np.asarray(data_values).take(data_values_pos)
233-
if index_values:
234-
index_values_pos = np.random.randint(0, len(index_values), size)
235-
series_index = np.asarray(index_values).take(index_values_pos)
236-
else:
237-
series_index = None
230+
series_data = take_k_elements(size, data_values, repeat)
231+
series_index = take_k_elements(size, index_values, repeat) if index_values else None
238232

239233
return pandas.Series(series_data, series_index, name)

0 commit comments

Comments
 (0)