@@ -36,31 +36,57 @@ def numba_funcify_CumOp(op, node, **kwargs):
3636 mode = op .mode
3737 ndim = node .outputs [0 ].ndim
3838
39+ if axis < 0 :
40+ axis = ndim + axis
41+ if axis < 0 or axis >= ndim :
42+ raise ValueError (f"Invalid axis { axis } for array with ndim { ndim } " )
43+
3944 reaxis_first = (axis ,) + tuple (i for i in range (ndim ) if i != axis )
4045
4146 if mode == "add" :
42- np_func = np .add
43- identity = 0
47+
48+ if ndim == 1 :
49+ @numba_basic .numba_njit (fastmath = config .numba__fastmath )
50+ def cumop (x ):
51+ return np .cumsum (x )
52+
53+ else :
54+ @numba_basic .numba_njit (boundscheck = False , fastmath = config .numba__fastmath )
55+ def cumop (x ):
56+ out_dtype = x .dtype
57+ if x .shape [axis ] < 2 :
58+ return x .astype (out_dtype )
59+
60+ x_axis_first = x .transpose (reaxis_first )
61+ res = np .empty (x_axis_first .shape , dtype = out_dtype )
62+
63+ res [0 ] = x_axis_first [0 ]
64+ for m in range (1 , x .shape [axis ]):
65+ res [m ] = res [m - 1 ] + x_axis_first [m ]
66+
67+ return res .transpose (reaxis_first )
68+
4469 else :
45- np_func = np .multiply
46- identity = 1
70+ if ndim == 1 :
71+ @numba_basic .numba_njit (fastmath = config .numba__fastmath )
72+ def cumop (x ):
73+ return np .cumprod (x )
4774
48- @numba_basic .numba_njit (boundscheck = False , fastmath = config .numba__fastmath )
49- def cumop (x ):
50- out_dtype = x .dtype
51- if x .shape [axis ] < 2 :
52- return x .astype (out_dtype )
75+ else :
76+ @numba_basic .numba_njit (boundscheck = False , fastmath = config .numba__fastmath )
77+ def cumop (x ):
78+ out_dtype = x .dtype
79+ if x .shape [axis ] < 2 :
80+ return x .astype (out_dtype )
5381
54- x_axis_first = x .transpose (reaxis_first )
55- res = np .empty (x_axis_first .shape , dtype = out_dtype )
82+ x_axis_first = x .transpose (reaxis_first )
83+ res = np .empty (x_axis_first .shape , dtype = out_dtype )
5684
57- for m in numba .prange (x .shape [axis ]):
58- if m == 0 :
59- np_func (identity , x_axis_first [m ], res [m ])
60- else :
61- np_func (res [m - 1 ], x_axis_first [m ], res [m ])
85+ res [0 ] = x_axis_first [0 ]
86+ for m in range (1 , x .shape [axis ]):
87+ res [m ] = res [m - 1 ] * x_axis_first [m ]
6288
63- return res .transpose (reaxis_first )
89+ return res .transpose (reaxis_first )
6490
6591 return cumop
6692
0 commit comments