Skip to content

Commit 0933d20

Browse files
committed
Normalize negative axes
1 parent ee9a6ff commit 0933d20

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ def create_vectorize_func(
164164
return elemwise_fn
165165

166166

167+
def normalize_axis(axis, ndim):
168+
if axis is None:
169+
return axis
170+
171+
if axis < 0:
172+
axis = ndim + axis
173+
174+
if axis < 0 or axis >= ndim:
175+
raise np.AxisError(ndim=ndim, axis=axis)
176+
return axis
177+
178+
167179
def create_axis_reducer(
168180
scalar_op: Op,
169181
identity: Union[np.ndarray, Number],
@@ -218,6 +230,8 @@ def careduce_axis(x):
218230
219231
"""
220232

233+
axis = normalize_axis(axis, ndim)
234+
221235
reduce_elemwise_fn_name = "careduce_axis"
222236

223237
identity = str(identity)
@@ -340,6 +354,8 @@ def careduce_maximum(input):
340354
if len(axes) == 1:
341355
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
342356

357+
axes = [normalize_axis(axis, ndim) for axis in axes]
358+
343359
careduce_fn_name = f"careduce_{scalar_op}"
344360
global_env = {}
345361
to_reduce = reversed(sorted(axes))
@@ -409,6 +425,8 @@ def jit_compile_reducer(node, fn, **kwds):
409425

410426

411427
def create_axis_apply_fn(fn, axis, ndim, dtype):
428+
axis = normalize_axis(axis, ndim)
429+
412430
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
413431

414432
@numba_basic.numba_njit(boundscheck=False)
@@ -609,6 +627,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
609627
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
610628
axis = op.axis
611629

630+
axis = normalize_axis(axis, x_at.ndim)
631+
612632
if axis is not None:
613633
reduce_max_py = create_axis_reducer(
614634
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
@@ -646,6 +666,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
646666
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
647667

648668
axis = op.axis
669+
axis = normalize_axis(axis, sm_at.ndim)
649670
if axis is not None:
650671
reduce_sum_py = create_axis_reducer(
651672
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
@@ -676,6 +697,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
676697
x_dtype = x_at.type.numpy_dtype
677698
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
678699
axis = op.axis
700+
axis = normalize_axis(axis, x_at.ndim)
679701

680702
if axis is not None:
681703
reduce_max_py = create_axis_reducer(

0 commit comments

Comments
 (0)