Skip to content

Commit 429ba6c

Browse files
committed
Fix numba impl of CumOp
1 parent 98be9c5 commit 429ba6c

File tree

1 file changed

+43
-17
lines changed

1 file changed

+43
-17
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,31 +36,57 @@ def numba_funcify_CumOp(op, node, **kwargs):
3636
mode = op.mode
3737
ndim = node.outputs[0].ndim
3838

39+
if axis < 0:
40+
axis = ndim + axis
41+
if axis < 0 or axis >= ndim:
42+
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
43+
3944
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
4045

4146
if mode == "add":
42-
np_func = np.add
43-
identity = 0
47+
48+
if ndim == 1:
49+
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
50+
def cumop(x):
51+
return np.cumsum(x)
52+
53+
else:
54+
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
55+
def cumop(x):
56+
out_dtype = x.dtype
57+
if x.shape[axis] < 2:
58+
return x.astype(out_dtype)
59+
60+
x_axis_first = x.transpose(reaxis_first)
61+
res = np.empty(x_axis_first.shape, dtype=out_dtype)
62+
63+
res[0] = x_axis_first[0]
64+
for m in range(1, x.shape[axis]):
65+
res[m] = res[m - 1] + x_axis_first[m]
66+
67+
return res.transpose(reaxis_first)
68+
4469
else:
45-
np_func = np.multiply
46-
identity = 1
70+
if ndim == 1:
71+
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
72+
def cumop(x):
73+
return np.cumprod(x)
4774

48-
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
49-
def cumop(x):
50-
out_dtype = x.dtype
51-
if x.shape[axis] < 2:
52-
return x.astype(out_dtype)
75+
else:
76+
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
77+
def cumop(x):
78+
out_dtype = x.dtype
79+
if x.shape[axis] < 2:
80+
return x.astype(out_dtype)
5381

54-
x_axis_first = x.transpose(reaxis_first)
55-
res = np.empty(x_axis_first.shape, dtype=out_dtype)
82+
x_axis_first = x.transpose(reaxis_first)
83+
res = np.empty(x_axis_first.shape, dtype=out_dtype)
5684

57-
for m in numba.prange(x.shape[axis]):
58-
if m == 0:
59-
np_func(identity, x_axis_first[m], res[m])
60-
else:
61-
np_func(res[m - 1], x_axis_first[m], res[m])
85+
res[0] = x_axis_first[0]
86+
for m in range(1, x.shape[axis]):
87+
res[m] = res[m - 1] * x_axis_first[m]
6288

63-
return res.transpose(reaxis_first)
89+
return res.transpose(reaxis_first)
6490

6591
return cumop
6692

0 commit comments

Comments
 (0)