33import numpy as np
44
55from pytensor import config
6- from pytensor .link .numba .dispatch .basic import numba_funcify , numba_njit
6+ from pytensor .link .numba .dispatch import basic as numba_basic
7+ from pytensor .link .numba .dispatch .basic import numba_funcify
78from pytensor .link .numba .dispatch .linalg .decomposition .cholesky import _cholesky
89from pytensor .link .numba .dispatch .linalg .decomposition .lu import (
910 _lu_1 ,
@@ -63,7 +64,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6364 if dtype in complex_dtypes :
6465 raise NotImplementedError (_COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = op ))
6566
66- @numba_njit
67+ @numba_basic . numba_njit
6768 def cholesky (a ):
6869 if check_finite :
6970 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
@@ -95,7 +96,7 @@ def pivot_to_permutation(op, node, **kwargs):
9596 inverse = op .inverse
9697 dtype = node .outputs [0 ].dtype
9798
98- @numba_njit
99+ @numba_basic . numba_njit
99100 def numba_pivot_to_permutation (piv ):
100101 p_inv = _pivot_to_permutation (piv , dtype )
101102
@@ -118,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
118119 if dtype in complex_dtypes :
119120 NotImplementedError (_COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = op ))
120121
121- @numba_njit (inline = "always" )
122+ @numba_basic . numba_njit (inline = "always" )
122123 def lu (a ):
123124 if check_finite :
124125 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
@@ -165,7 +166,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
165166 if dtype in complex_dtypes :
166167 NotImplementedError (_COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = op ))
167168
168- @numba_njit
169+ @numba_basic . numba_njit
169170 def lu_factor (a ):
170171 if check_finite :
171172 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
@@ -185,7 +186,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
185186 dtype = node .outputs [0 ].dtype
186187
187188 # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
188- @numba_njit
189+ @numba_basic . numba_njit
189190 def block_diag (* arrs ):
190191 shapes = np .array ([a .shape for a in arrs ], dtype = "int" )
191192 out_shape = [int (s ) for s in np .sum (shapes , axis = 0 )]
@@ -235,7 +236,7 @@ def numba_funcify_Solve(op, node, **kwargs):
235236 )
236237 solve_fn = _solve_gen
237238
238- @numba_njit
239+ @numba_basic . numba_njit
239240 def solve (a , b ):
240241 if check_finite :
241242 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
@@ -267,7 +268,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
267268 _COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = "Solve Triangular" )
268269 )
269270
270- @numba_njit
271+ @numba_basic . numba_njit
271272 def solve_triangular (a , b ):
272273 if check_finite :
273274 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
@@ -304,7 +305,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
304305 if dtype in complex_dtypes :
305306 raise NotImplementedError (_COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = op ))
306307
307- @numba_njit
308+ @numba_basic . numba_njit
308309 def cho_solve (c , b ):
309310 if check_finite :
310311 if np .any (np .bitwise_or (np .isinf (c ), np .isnan (c ))):
@@ -337,7 +338,7 @@ def numba_funcify_QR(op, node, **kwargs):
337338 integer_input = dtype in integer_dtypes
338339 in_dtype = config .floatX if integer_input else dtype
339340
340- @numba_njit (cache = False )
341+ @numba_basic . numba_njit (cache = False )
341342 def qr (a ):
342343 if check_finite :
343344 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
0 commit comments