 
 import numba
 import numpy as np
+from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
 from pytensor import config
 from pytensor.graph.basic import Apply
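
For context on the new import: these are the NumPy helpers that replace the hand-rolled `normalize_axis` removed below. A quick behavioural sketch (assuming a NumPy version that still exposes them from `numpy.core.numeric`; newer releases also move `AxisError` under `numpy.exceptions`):

```python
import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

# Negative axes wrap around, exactly like the removed helper.
assert normalize_axis_index(-1, 3) == 2
# The tuple variant normalizes several axes at once and accepts a bare int.
assert normalize_axis_tuple((0, -1), 3) == (0, 2)
assert normalize_axis_tuple(1, 3) == (1,)

# Out-of-range axes raise AxisError, as the old helper did.
try:
    normalize_axis_index(3, 3)
except np.AxisError as exc:
    print(exc)
```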
@@ -164,18 +165,6 @@ def create_vectorize_func(
     return elemwise_fn
 
 
-def normalize_axis(axis, ndim):
-    if axis is None:
-        return axis
-
-    if axis < 0:
-        axis = ndim + axis
-
-    if axis < 0 or axis >= ndim:
-        raise np.AxisError(ndim=ndim, axis=axis)
-    return axis
-
-
 def create_axis_reducer(
     scalar_op: Op,
     identity: Union[np.ndarray, Number],
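
One behavioural difference worth flagging: the removed helper passed `axis=None` straight through, while `normalize_axis_index` only accepts integers. That is why the Softmax/SoftmaxGrad/LogSoftmax hunks below move the call inside the `if axis is not None:` branch. A minimal sketch of the pattern (the `resolve_axis` wrapper is only for illustration):

```python
from numpy.core.numeric import normalize_axis_index

def resolve_axis(axis, ndim):
    # old: normalize_axis(None, ndim) returned None
    # new: only normalize when an axis is actually given
    if axis is not None:
        axis = normalize_axis_index(axis, ndim)
    return axis

assert resolve_axis(None, 2) is None
assert resolve_axis(-1, 2) == 1
```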
@@ -230,7 +219,7 @@ def careduce_axis(x):
 
     """
 
-    axis = normalize_axis(axis, ndim)
+    axis = normalize_axis_index(axis, ndim)
 
     reduce_elemwise_fn_name = "careduce_axis"
 
@@ -354,7 +343,7 @@ def careduce_maximum(input):
     if len(axes) == 1:
         return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
 
-    axes = [normalize_axis(axis, ndim) for axis in axes]
+    axes = normalize_axis_tuple(axes, ndim)
 
     careduce_fn_name = f"careduce_{scalar_op}"
     global_env = {}
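
For the multi-axis path, the single NumPy call replaces the old list comprehension. It returns a tuple rather than a list and, unlike the comprehension, also rejects duplicate axes; a small sketch:

```python
from numpy.core.numeric import normalize_axis_tuple

ndim = 4
assert normalize_axis_tuple((0, -1, 2), ndim) == (0, 3, 2)

try:
    normalize_axis_tuple((1, -3), ndim)  # -3 also resolves to axis 1
except ValueError as exc:
    print(exc)  # "repeated axis ..."
```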
@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds):
 
 
 def create_axis_apply_fn(fn, axis, ndim, dtype):
-    axis = normalize_axis(axis, ndim)
+    axis = normalize_axis_index(axis, ndim)
 
     reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
 
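
As a side note on the context line above: `reaxis_first` lists every other axis first and puts the applied axis last, so a transpose with this permutation moves that axis to the final dimension. For example:

```python
ndim, axis = 3, 1
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
assert reaxis_first == (0, 2, 1)
```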
@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
     axis = op.axis
 
-    axis = normalize_axis(axis, x_at.ndim)
-
     if axis is not None:
+        axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
             scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
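
The `keepdims=True` max-reducer is built here because the per-axis maximum must stay broadcastable against the input for the usual numerically stable softmax. A rough NumPy sketch of that computation (not the generated numba kernel):

```python
import numpy as np

def softmax_sketch(x, axis):
    x_max = np.max(x, axis=axis, keepdims=True)   # role of reduce_max_py above
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)
```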
@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
     sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
 
     axis = op.axis
-    axis = normalize_axis(axis, sm_at.ndim)
     if axis is not None:
+        axis = normalize_axis_index(axis, sm_at.ndim)
         reduce_sum_py = create_axis_reducer(
             add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
         )
@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
     axis = op.axis
-    axis = normalize_axis(axis, x_at.ndim)
 
     if axis is not None:
+        axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
             scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )