11import contextlib
2- import inspect
32from collections .abc import Callable , Iterable
43from typing import TYPE_CHECKING , Any
54from unittest import mock
@@ -151,38 +150,13 @@ def njit_noop(*args, **kwargs):
151150 else :
152151 return lambda x : x
153152
154- def vectorize_noop (* args , ** kwargs ):
155- def wrap (fn ):
156- # `numba.vectorize` allows an `out` positional argument. We need
157- # to account for that
158- sig = inspect .signature (fn )
159- nparams = len (sig .parameters )
160-
161- def inner_vec (* args ):
162- if len (args ) > nparams :
163- # An `out` argument has been specified for an in-place
164- # operation
165- out = args [- 1 ]
166- out [...] = np .vectorize (fn )(* args [:nparams ])
167- return out
168- else :
169- return np .vectorize (fn )(* args )
170-
171- return inner_vec
172-
173- if len (args ) == 1 and callable (args [0 ]):
174- return wrap (args [0 ], ** kwargs )
175- else :
176- return wrap
177-
178153 def py_global_numba_func (func ):
179154 if hasattr (func , "py_func" ):
180155 return func .py_func
181156 return func
182157
183158 mocks = [
184159 mock .patch ("numba.njit" , njit_noop ),
185- mock .patch ("numba.vectorize" , vectorize_noop ),
186160 mock .patch (
187161 "pytensor.link.numba.dispatch.basic.global_numba_func" ,
188162 py_global_numba_func ,
@@ -191,9 +165,6 @@ def py_global_numba_func(func):
191165 "pytensor.link.numba.dispatch.basic.tuple_setitem" , py_tuple_setitem
192166 ),
193167 mock .patch ("pytensor.link.numba.dispatch.basic.numba_njit" , njit_noop ),
194- mock .patch (
195- "pytensor.link.numba.dispatch.basic.numba_vectorize" , vectorize_noop
196- ),
197168 mock .patch (
198169 "pytensor.link.numba.dispatch.basic.direct_cast" , lambda x , dtype : x
199170 ),
0 commit comments