Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 632b554

Browse files
authored
Optimize series.rolling.sum() (#608)
Optimize series.rolling.sum()
1 parent a98cbeb commit 632b554

File tree

4 files changed

+135
-21
lines changed

4 files changed

+135
-21
lines changed

sdc/datatypes/hpat_pandas_series_rolling_functions.py

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sdc.datatypes.common_functions import _sdc_pandas_series_align
3838
from sdc.datatypes.hpat_pandas_series_rolling_types import SeriesRollingType
3939
from sdc.hiframes.pd_series_type import SeriesType
40+
from sdc.utilities.prange_utils import parallel_chunks
4041
from sdc.utilities.sdc_typing_utils import TypeChecker
4142
from sdc.utilities.utils import sdc_overload_method, sdc_register_jitable
4243

@@ -214,12 +215,6 @@ def arr_std(arr, ddof):
214215
return arr_var(arr, ddof) ** 0.5
215216

216217

217-
@sdc_register_jitable
218-
def arr_sum(arr):
219-
"""Calculate sum of values"""
220-
return arr.sum()
221-
222-
223218
@sdc_register_jitable
224219
def arr_var(arr, ddof):
225220
"""Calculate unbiased variance of values"""
@@ -308,12 +303,87 @@ def apply_minp(arr, ddof, minp):
308303
gen_hpat_pandas_series_rolling_impl(arr_skew))
309304
hpat_pandas_rolling_series_std_impl = register_jitable(
310305
gen_hpat_pandas_series_rolling_ddof_impl(arr_std))
311-
hpat_pandas_rolling_series_sum_impl = register_jitable(
312-
gen_hpat_pandas_series_rolling_impl(arr_sum))
313306
hpat_pandas_rolling_series_var_impl = register_jitable(
314307
gen_hpat_pandas_series_rolling_ddof_impl(arr_var))
315308

316309

310+
@sdc_register_jitable
def pop_sum(value, nfinite, result):
    """Remove a value that left the window from the running sum.

    A value that is not finite (per ``numpy.isfinite``) is ignored and the
    state is returned untouched. Otherwise the count of finite values is
    decreased by one and the value is subtracted from the running sum.

    Returns the updated ``(nfinite, result)`` pair.
    """
    if not numpy.isfinite(value):
        return nfinite, result

    return nfinite - 1, result - value
318+
319+
320+
@sdc_register_jitable
def put_sum(value, nfinite, result):
    """Add a value that entered the window to the running sum.

    A value that is not finite (per ``numpy.isfinite``) is ignored and the
    state is returned untouched. Otherwise the count of finite values is
    increased by one and the value is added to the running sum.

    Returns the updated ``(nfinite, result)`` pair.
    """
    if not numpy.isfinite(value):
        return nfinite, result

    return nfinite + 1, result + value
328+
329+
330+
@sdc_register_jitable
def result_or_nan(nfinite, minp, result):
    """Return `result` when the window holds enough finite values, else NaN.

    `nfinite` is the number of finite values currently in the window and
    `minp` is the minimum number required for a result to be reported.
    """
    return result if nfinite >= minp else numpy.nan
337+
338+
339+
def gen_sdc_pandas_series_rolling_impl(pop, put, init_result=numpy.nan):
    """Generate series rolling methods implementations based on pop/put funcs

    The generated implementation splits the input into the chunks returned
    by ``parallel_chunks`` and walks them in a ``prange`` loop. Inside a
    chunk the window state is updated incrementally: every step adds the
    incoming value with `put` and removes the outgoing one with `pop`, so a
    window is never recomputed from scratch.

    Parameters
    ----------
    pop: callable ``(value, nfinite, result) -> (nfinite, result)``
        Removes a value that left the window from the running state.
    put: callable ``(value, nfinite, result) -> (nfinite, result)``
        Adds a value that entered the window to the running state.
    init_result:
        Value of the running result before any value has been put.

    Returns
    -------
    callable
        ``impl(self)`` producing a ``pandas.Series`` of float64 values that
        reuses the index and the name of the input series.
    """
    def impl(self):
        win = self._window
        minp = self._min_periods

        input_series = self._data
        input_arr = input_series._data
        length = len(input_arr)
        output_arr = numpy.empty(length, dtype=float64)

        chunks = parallel_chunks(length)
        for i in prange(len(chunks)):
            chunk = chunks[i]
            nfinite = 0
            result = init_result

            if win == 0:
                # An empty window never contains a value. Without this
                # branch the bounds below would still put the first element
                # of the chunk and then put/pop the same element on every
                # step, leaking that first value into every output.
                for idx in range(chunk.start, chunk.stop):
                    output_arr[idx] = result_or_nan(nfinite, minp, result)
                continue

            # Prelude: values located before the chunk that belong to the
            # window of the chunk's first element. They only warm up the
            # state; their outputs are written by the chunk that owns them.
            prelude_start = max(0, chunk.start - win + 1)
            prelude_stop = chunk.start

            # Interlude: the window is still growing, nothing to pop yet.
            interlude_start = prelude_stop
            interlude_stop = min(prelude_start + win, chunk.stop)

            for idx in range(prelude_start, prelude_stop):
                value = input_arr[idx]
                nfinite, result = put(value, nfinite, result)

            for idx in range(interlude_start, interlude_stop):
                value = input_arr[idx]
                nfinite, result = put(value, nfinite, result)
                output_arr[idx] = result_or_nan(nfinite, minp, result)

            # Steady state: the window is full, so each step puts the new
            # value and pops the one that is `win` positions behind it.
            for idx in range(interlude_stop, chunk.stop):
                put_value = input_arr[idx]
                pop_value = input_arr[idx - win]
                nfinite, result = put(put_value, nfinite, result)
                nfinite, result = pop(pop_value, nfinite, result)
                output_arr[idx] = result_or_nan(nfinite, minp, result)

        return pandas.Series(output_arr, input_series._index,
                             name=input_series._name)
    return impl
381+
382+
383+
# Rolling sum implementation built from the incremental put/pop helpers;
# the running sum starts from 0.0 instead of the generator's default NaN.
sdc_pandas_series_rolling_sum_impl = register_jitable(
    gen_sdc_pandas_series_rolling_impl(pop=pop_sum, put=put_sum,
                                       init_result=0.)
)
385+
386+
317387
@sdc_rolling_overload(SeriesRollingType, 'apply')
318388
def hpat_pandas_series_rolling_apply(self, func, raw=None):
319389

@@ -619,13 +689,13 @@ def hpat_pandas_series_rolling_std(self, ddof=1):
619689
return hpat_pandas_rolling_series_std_impl
620690

621691

622-
@sdc_overload_method(SeriesRollingType, 'sum')
def hpat_pandas_series_rolling_sum(self):
    """Overload of the ``sum`` method of SeriesRollingType.

    Runs the TypeChecker check of `self` against SeriesRollingType and
    returns the pop/put based rolling sum implementation.
    """
    TypeChecker('Method rolling.sum().').check(self, SeriesRollingType)

    return sdc_pandas_series_rolling_sum_impl
629699

630700

631701
@sdc_rolling_overload(SeriesRollingType, 'var')

sdc/tests/test_rolling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,8 @@ def test_impl(obj, window, min_periods):
847847
hpat_func = self.jit(test_impl)
848848
assert_equal = self._get_assert_equal(obj)
849849

850-
for window in range(0, len(obj) + 3, 2):
851-
for min_periods in range(0, window + 1, 2):
850+
for window in range(len(obj) + 2):
851+
for min_periods in range(window):
852852
with self.subTest(obj=obj, window=window,
853853
min_periods=min_periods):
854854
jit_result = hpat_func(obj, window, min_periods)

sdc/tests/tests_perf/test_perf_series_rolling.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,29 @@
2424
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
2525
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
# *****************************************************************************
27-
import string
27+
2828
import time
2929

30-
import numba
3130
import pandas
3231
import numpy as np
3332

3433
from sdc.tests.test_utils import test_global_input_data_float64
3534
from sdc.tests.tests_perf.test_perf_base import TestBase
36-
from sdc.tests.tests_perf.test_perf_utils import (calc_compilation, get_times,
37-
perf_data_gen_fixed_len)
35+
from sdc.tests.tests_perf.test_perf_utils import perf_data_gen_fixed_len
3836
from .generator import generate_test_cases
3937
from .generator import TestCase as TC
4038

4139

40+
# Source template of a timed use case. It is rendered with str.format() and
# then executed, so the indentation inside the literal is significant.
# Placeholders: method_name, extra_usecase_params, ncalls, rolling_params,
# method_params. The rendered function calls the rolling method `ncalls`
# times and returns (elapsed seconds, result of the last call).
rolling_usecase_tmpl = """
def series_rolling_{method_name}_usecase(data, {extra_usecase_params}):
    start_time = time.time()
    for i in range({ncalls}):
        res = data.rolling({rolling_params}).{method_name}({method_params})
    end_time = time.time()
    return end_time - start_time, res
"""
48+
49+
4250
def get_rolling_params(window=100, min_periods=None):
4351
"""Generate supported rolling parameters"""
4452
rolling_params = [f'{window}']
@@ -48,14 +56,37 @@ def get_rolling_params(window=100, min_periods=None):
4856
return ', '.join(rolling_params)
4957

5058

59+
def gen_series_rolling_usecase(method_name, rolling_params=None,
                               extra_usecase_params='',
                               method_params='', ncalls=1):
    """Build a timed use-case function for a series rolling method.

    The function source is rendered from ``rolling_usecase_tmpl`` and
    executed with ``np`` and ``time`` as its globals. When `rolling_params`
    is empty, the value returned by ``get_rolling_params()`` is used.

    Returns the generated ``series_rolling_<method_name>_usecase`` function.
    """
    rolling_params = rolling_params or get_rolling_params()
    usecase_name = f'series_rolling_{method_name}_usecase'

    source = rolling_usecase_tmpl.format(
        method_name=method_name,
        extra_usecase_params=extra_usecase_params,
        rolling_params=rolling_params,
        method_params=method_params,
        ncalls=ncalls,
    )

    # Separate dicts: the first supplies the globals seen by the generated
    # function, the second receives the function object itself.
    exec_globals = {'np': np, 'time': time}
    exec_locals = {}
    exec(source, exec_globals, exec_locals)

    return exec_locals[usecase_name]
80+
81+
5182
# python -m sdc.runtests sdc.tests.tests_perf.test_perf_series_rolling.TestSeriesRollingMethods
5283
class TestSeriesRollingMethods(TestBase):
53-
# more than 19 columns raise SystemError: CPUDispatcher() returned a result with an error set
54-
max_columns_num = 19
55-
5684
@classmethod
def setUpClass(cls):
    """Set up the base class and register per-method benchmark settings."""
    super().setUpClass()
    # method name -> (ncalls, total data lengths), unpacked in that order
    # by _test_series_rolling_method
    sum_ncalls, sum_lengths = 100, [8 * 10 ** 5]
    cls.map_ncalls_dlength = {'sum': (sum_ncalls, sum_lengths)}
5990

6091
def _test_case(self, pyfunc, name, total_data_length, data_num=1,
6192
input_data=test_global_input_data_float64):
@@ -82,6 +113,20 @@ def _test_case(self, pyfunc, name, total_data_length, data_num=1,
82113
self._test_jit(pyfunc, base, *args)
83114
self._test_py(pyfunc, base, *args)
84115

116+
def _test_series_rolling_method(self, name, rolling_params=None,
                                extra_usecase_params='', method_params=''):
    """Generate the use case for rolling method `name` and run it.

    The number of calls and the data lengths come from
    ``self.map_ncalls_dlength[name]``. The use case receives one data
    argument plus one for every ', '-separated entry of
    `extra_usecase_params`.
    """
    ncalls, total_data_length = self.map_ncalls_dlength[name]

    extra_count = 0
    if extra_usecase_params:
        extra_count = len(extra_usecase_params.split(', '))

    usecase = gen_series_rolling_usecase(
        name,
        rolling_params=rolling_params,
        extra_usecase_params=extra_usecase_params,
        method_params=method_params,
        ncalls=ncalls,
    )
    self._test_case(usecase, name, total_data_length,
                    data_num=1 + extra_count)
126+
127+
def test_series_rolling_sum(self):
    """Run the series rolling benchmark for the ``sum`` method."""
    method = 'sum'
    self._test_series_rolling_method(method)
129+
85130

86131
cases = [
87132
TC(name='apply', size=[10 ** 7], params='func=lambda x: np.nan if len(x) == 0 else x.mean()'),
@@ -96,7 +141,6 @@ def _test_case(self, pyfunc, name, total_data_length, data_num=1,
96141
TC(name='quantile', size=[10 ** 7], params='0.2'),
97142
TC(name='skew', size=[10 ** 7]),
98143
TC(name='std', size=[10 ** 7]),
99-
TC(name='sum', size=[10 ** 7]),
100144
TC(name='var', size=[10 ** 7]),
101145
]
102146

sdc/utilities/prange_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import sdc
3030

3131
from typing import NamedTuple
32-
from sdc.utilities.utils import sdc_overload, sdc_register_jitable
32+
from sdc.utilities.utils import sdc_register_jitable
3333

3434

3535
class Chunk(NamedTuple):

0 commit comments

Comments
 (0)