Skip to content

Commit 88ca2d2

Browse files
committed
Systematic use of mockable numba_basic.numba_jit
Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
1 parent 27d7463 commit 88ca2d2

File tree

12 files changed

+55
-52
lines changed

12 files changed

+55
-52
lines changed

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from numba.core.extending import overload
55
from numba.np.unsafe.ndarray import to_fixed_tuple
66

7-
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
7+
from pytensor.link.numba.dispatch import basic as numba_basic
8+
from pytensor.link.numba.dispatch.basic import numba_funcify
89
from pytensor.link.numba.dispatch.vectorize_codegen import (
910
_jit_options,
1011
_vectorized,
@@ -56,7 +57,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
5657
src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
5758
src += ")"
5859

59-
to_tuple = numba_njit(
60+
to_tuple = numba_basic.numba_njit(
6061
compile_function_src(
6162
src,
6263
"to_tuple",

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ def numba_funcify_Sum(op, node, **kwargs):
359359

360360
if ndim_input == len(axes):
361361
# Slightly faster than `numba_funcify_CAReduce` for this case
362-
@numba_njit
362+
@numba_basic.numba_njit
363363
def impl_sum(array):
364364
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
365365

366366
elif len(axes) == 0:
367367
# These cases should be removed by rewrites!
368-
@numba_njit
368+
@numba_basic.numba_njit
369369
def impl_sum(array):
370370
return np.asarray(array, dtype=out_dtype)
371371

@@ -615,25 +615,25 @@ def numba_funcify_Dot(op, node, **kwargs):
615615

616616
if x_dtype == dot_dtype and y_dtype == dot_dtype:
617617

618-
@numba_njit
618+
@numba_basic.numba_njit
619619
def dot(x, y):
620620
return np.asarray(np.dot(x, y))
621621

622622
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
623623

624-
@numba_njit
624+
@numba_basic.numba_njit
625625
def dot(x, y):
626626
return np.asarray(np.dot(x, y.astype(dot_dtype)))
627627

628628
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
629629

630-
@numba_njit
630+
@numba_basic.numba_njit
631631
def dot(x, y):
632632
return np.asarray(np.dot(x.astype(dot_dtype), y))
633633

634634
else:
635635

636-
@numba_njit()
636+
@numba_basic.numba_njit
637637
def dot(x, y):
638638
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
639639

@@ -642,7 +642,7 @@ def dot(x, y):
642642

643643
else:
644644

645-
@numba_njit
645+
@numba_basic.numba_njit
646646
def dot_with_cast(x, y):
647647
return dot(x, y).astype(out_dtype)
648648

@@ -653,7 +653,7 @@ def dot_with_cast(x, y):
653653
def numba_funcify_BatchedDot(op, node, **kwargs):
654654
dtype = node.outputs[0].type.numpy_dtype
655655

656-
@numba_njit
656+
@numba_basic.numba_njit
657657
def batched_dot(x, y):
658658
# Numba does not support 3D matmul
659659
# https://github.com/numba/numba/issues/3804

pytensor/link/numba/dispatch/linalg/decomposition/lu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@
22
from typing import Literal
33

44
import numpy as np
5-
from numba import njit as numba_njit
65
from numba.core.extending import overload
76
from numba.np.linalg import ensure_lapack
87
from scipy import linalg
98

9+
from pytensor.link.numba.dispatch import basic as numba_basic
1010
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
1111
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
1212

1313

14-
@numba_njit
14+
@numba_basic.numba_njit
1515
def _pivot_to_permutation(p, dtype):
1616
p_inv = np.arange(len(p)).astype(dtype)
1717
for i in range(len(p)):
1818
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
1919
return p_inv
2020

2121

22-
@numba_njit
22+
@numba_basic.numba_njit
2323
def _lu_factor_to_lu(a, dtype, overwrite_a):
2424
A_copy, IPIV, _INFO = _getrf(a, overwrite_a=overwrite_a)
2525

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from numpy import ndarray
77
from scipy import linalg
88

9+
from pytensor.link.numba.dispatch import basic as numba_basic
910
from pytensor.link.numba.dispatch import numba_funcify
10-
from pytensor.link.numba.dispatch.basic import numba_njit
1111
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1212
_LAPACK,
1313
_get_underlying_float,
@@ -27,7 +27,7 @@
2727
)
2828

2929

30-
@numba_njit
30+
@numba_basic.numba_njit
3131
def tridiagonal_norm(du, d, dl):
3232
# Adapted from scipy _matrix_norm_tridiagonal:
3333
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
@@ -346,7 +346,7 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
346346
overwrite_d = op.overwrite_d
347347
overwrite_du = op.overwrite_du
348348

349-
@numba_njit(cache=False)
349+
@numba_basic.numba_njit(cache=False)
350350
def lu_factor_tridiagonal(dl, d, du):
351351
dl, d, du, du2, ipiv, _ = _gttrf(
352352
dl,
@@ -368,7 +368,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
368368
overwrite_b = op.overwrite_b
369369
transposed = op.transposed
370370

371-
@numba_njit(cache=False)
371+
@numba_basic.numba_njit(cache=False)
372372
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
373373
x, _ = _gttrs(
374374
dl,

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs):
3030

3131
if not compute_uv:
3232

33-
@numba_basic.numba_njit()
33+
@numba_basic.numba_njit
3434
def svd(x):
3535
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
3636
return ret
3737

3838
else:
3939

40-
@numba_basic.numba_njit()
40+
@numba_basic.numba_njit
4141
def svd(x):
4242
return np.linalg.svd(inputs_cast(x), full_matrices)
4343

pytensor/link/numba/dispatch/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def {name}(rng, {input_signature}):
9191
def numba_core_BernoulliRV(op, node):
9292
out_dtype = node.outputs[1].type.numpy_dtype
9393

94-
@numba_basic.numba_njit()
94+
@numba_basic.numba_njit
9595
def random(rng, p):
9696
return (
9797
direct_cast(0, out_dtype)

pytensor/link/numba/dispatch/shape.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from numba.np.unsafe import ndarray as numba_ndarray
55

6+
from pytensor.link.numba.dispatch import basic as numba_basic
67
from pytensor.link.numba.dispatch import numba_funcify
78
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
89
from pytensor.link.utils import compile_function_src
@@ -12,7 +13,7 @@
1213

1314
@numba_funcify.register(Shape)
1415
def numba_funcify_Shape(op, **kwargs):
15-
@numba_njit
16+
@numba_basic.numba_njit
1617
def shape(x):
1718
return np.asarray(np.shape(x))
1819

@@ -23,7 +24,7 @@ def shape(x):
2324
def numba_funcify_Shape_i(op, **kwargs):
2425
i = op.i
2526

26-
@numba_njit
27+
@numba_basic.numba_njit
2728
def shape_i(x):
2829
return np.asarray(np.shape(x)[i])
2930

@@ -61,13 +62,13 @@ def numba_funcify_Reshape(op, **kwargs):
6162

6263
if ndim == 0:
6364

64-
@numba_njit
65+
@numba_basic.numba_njit
6566
def reshape(x, shape):
6667
return np.asarray(x.item())
6768

6869
else:
6970

70-
@numba_njit
71+
@numba_basic.numba_njit
7172
def reshape(x, shape):
7273
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
7374
return np.reshape(

pytensor/link/numba/dispatch/signal/conv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
22
from numba.np.arraymath import _get_inner_prod
33

4+
from pytensor.link.numba.dispatch import basic as numba_basic
45
from pytensor.link.numba.dispatch import numba_funcify
5-
from pytensor.link.numba.dispatch.basic import numba_njit
66
from pytensor.tensor.signal.conv import Convolve1d
77

88

@@ -13,7 +13,7 @@ def numba_funcify_Convolve1d(op, node, **kwargs):
1313
out_dtype = node.outputs[0].type.dtype
1414
innerprod = _get_inner_prod(a_dtype, b_dtype)
1515

16-
@numba_njit
16+
@numba_basic.numba_njit
1717
def valid_convolve1d(x, y):
1818
nx = len(x)
1919
ny = len(y)
@@ -30,7 +30,7 @@ def valid_convolve1d(x, y):
3030

3131
return ret
3232

33-
@numba_njit
33+
@numba_basic.numba_njit
3434
def full_convolve1d(x, y):
3535
nx = len(x)
3636
ny = len(y)
@@ -59,7 +59,7 @@ def full_convolve1d(x, y):
5959

6060
return ret
6161

62-
@numba_njit
62+
@numba_basic.numba_njit
6363
def convolve_1d(x, y, mode):
6464
if mode:
6565
return full_convolve1d(x, y)

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44

55
from pytensor import config
6-
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
6+
from pytensor.link.numba.dispatch import basic as numba_basic
7+
from pytensor.link.numba.dispatch.basic import numba_funcify
78
from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky
89
from pytensor.link.numba.dispatch.linalg.decomposition.lu import (
910
_lu_1,
@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6364
if dtype in complex_dtypes:
6465
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
6566

66-
@numba_njit
67+
@numba_basic.numba_njit
6768
def cholesky(a):
6869
if check_finite:
6970
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
9596
inverse = op.inverse
9697
dtype = node.outputs[0].dtype
9798

98-
@numba_njit
99+
@numba_basic.numba_njit
99100
def numba_pivot_to_permutation(piv):
100101
p_inv = _pivot_to_permutation(piv, dtype)
101102

@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
118119
if dtype in complex_dtypes:
119120
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
120121

121-
@numba_njit(inline="always")
122+
@numba_basic.numba_njit(inline="always")
122123
def lu(a):
123124
if check_finite:
124125
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
165166
if dtype in complex_dtypes:
166167
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
167168

168-
@numba_njit
169+
@numba_basic.numba_njit
169170
def lu_factor(a):
170171
if check_finite:
171172
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
185186
dtype = node.outputs[0].dtype
186187

187188
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
188-
@numba_njit
189+
@numba_basic.numba_njit
189190
def block_diag(*arrs):
190191
shapes = np.array([a.shape for a in arrs], dtype="int")
191192
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
235236
)
236237
solve_fn = _solve_gen
237238

238-
@numba_njit
239+
@numba_basic.numba_njit
239240
def solve(a, b):
240241
if check_finite:
241242
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
267268
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
268269
)
269270

270-
@numba_njit
271+
@numba_basic.numba_njit
271272
def solve_triangular(a, b):
272273
if check_finite:
273274
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
304305
if dtype in complex_dtypes:
305306
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
306307

307-
@numba_njit
308+
@numba_basic.numba_njit
308309
def cho_solve(c, b):
309310
if check_finite:
310311
if np.any(np.bitwise_or(np.isinf(c), np.isnan(c))):
@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
337338
integer_input = dtype in integer_dtypes
338339
in_dtype = config.floatX if integer_input else dtype
339340

340-
@numba_njit(cache=False)
341+
@numba_basic.numba_njit(cache=False)
341342
def qr(a):
342343
if check_finite:
343344
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):

pytensor/link/numba/dispatch/sort.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import numpy as np
44

5+
from pytensor.link.numba.dispatch import basic as numba_basic
56
from pytensor.link.numba.dispatch import numba_funcify
6-
from pytensor.link.numba.dispatch.basic import numba_njit
77
from pytensor.tensor.sort import ArgSortOp, SortOp
88

99

@@ -18,7 +18,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
1818
UserWarning,
1919
)
2020

21-
@numba_njit
21+
@numba_basic.numba_njit
2222
def sort_f(a, axis):
2323
axis = axis.item()
2424

@@ -45,7 +45,7 @@ def numba_funcify_ArgSortOp(op, node, **kwargs):
4545
UserWarning,
4646
)
4747

48-
@numba_njit
48+
@numba_basic.numba_njit
4949
def argort_f(X, axis):
5050
axis = axis.item()
5151

0 commit comments

Comments
 (0)