@@ -164,6 +164,18 @@ def create_vectorize_func(
164164 return elemwise_fn
165165
166166
167+ def normalize_axis (axis , ndim ):
168+ if axis is None :
169+ return axis
170+
171+ if axis < 0 :
172+ axis = ndim + axis
173+
174+ if axis < 0 or axis >= ndim :
175+ raise np .AxisError (ndim = ndim , axis = axis )
176+ return axis
177+
178+
167179def create_axis_reducer (
168180 scalar_op : Op ,
169181 identity : Union [np .ndarray , Number ],
@@ -218,6 +230,8 @@ def careduce_axis(x):
218230
219231 """
220232
233+ axis = normalize_axis (axis , ndim )
234+
221235 reduce_elemwise_fn_name = "careduce_axis"
222236
223237 identity = str (identity )
@@ -340,6 +354,8 @@ def careduce_maximum(input):
340354 if len (axes ) == 1 :
341355 return create_axis_reducer (scalar_op , identity , axes [0 ], ndim , dtype )
342356
357+ axes = [normalize_axis (axis , ndim ) for axis in axes ]
358+
343359 careduce_fn_name = f"careduce_{ scalar_op } "
344360 global_env = {}
345361 to_reduce = reversed (sorted (axes ))
@@ -409,6 +425,8 @@ def jit_compile_reducer(node, fn, **kwds):
409425
410426
411427def create_axis_apply_fn (fn , axis , ndim , dtype ):
428+ axis = normalize_axis (axis , ndim )
429+
412430 reaxis_first = tuple (i for i in range (ndim ) if i != axis ) + (axis ,)
413431
414432 @numba_basic .numba_njit (boundscheck = False )
@@ -609,6 +627,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
609627 x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
610628 axis = op .axis
611629
630+ axis = normalize_axis (axis , x_at .ndim )
631+
612632 if axis is not None :
613633 reduce_max_py = create_axis_reducer (
614634 scalar_maximum , - np .inf , axis , x_at .ndim , x_dtype , keepdims = True
@@ -646,6 +666,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
646666 sm_dtype = numba .np .numpy_support .from_dtype (sm_dtype )
647667
648668 axis = op .axis
669+ axis = normalize_axis (axis , sm_at .ndim )
649670 if axis is not None :
650671 reduce_sum_py = create_axis_reducer (
651672 add_as , 0.0 , axis , sm_at .ndim , sm_dtype , keepdims = True
@@ -676,6 +697,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
676697 x_dtype = x_at .type .numpy_dtype
677698 x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
678699 axis = op .axis
700+ axis = normalize_axis (axis , x_at .ndim )
679701
680702 if axis is not None :
681703 reduce_max_py = create_axis_reducer (
0 commit comments