Skip to content

Commit aa1b7c8

Browse files
committed
Remove unused numba_vectorize
1 parent 022f189 commit aa1b7c8

File tree

2 files changed

+0
-36
lines changed

2 files changed

+0
-36
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,6 @@ def numba_njit(*args, fastmath=None, **kwargs):
8787
return numba.njit(*args, fastmath=fastmath, **kwargs)
8888

8989

90-
def numba_vectorize(*args, **kwargs):
91-
if len(args) > 0 and callable(args[0]):
92-
return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
93-
94-
return numba.vectorize(*args, cache=config.numba__cache, **kwargs)
95-
96-
9790
def get_numba_type(
9891
pytensor_type: Type,
9992
layout: str = "A",

tests/link/numba/test_basic.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
import inspect
32
from collections.abc import Callable, Iterable
43
from typing import TYPE_CHECKING, Any
54
from unittest import mock
@@ -151,38 +150,13 @@ def njit_noop(*args, **kwargs):
151150
else:
152151
return lambda x: x
153152

154-
def vectorize_noop(*args, **kwargs):
155-
def wrap(fn):
156-
# `numba.vectorize` allows an `out` positional argument. We need
157-
# to account for that
158-
sig = inspect.signature(fn)
159-
nparams = len(sig.parameters)
160-
161-
def inner_vec(*args):
162-
if len(args) > nparams:
163-
# An `out` argument has been specified for an in-place
164-
# operation
165-
out = args[-1]
166-
out[...] = np.vectorize(fn)(*args[:nparams])
167-
return out
168-
else:
169-
return np.vectorize(fn)(*args)
170-
171-
return inner_vec
172-
173-
if len(args) == 1 and callable(args[0]):
174-
return wrap(args[0], **kwargs)
175-
else:
176-
return wrap
177-
178153
def py_global_numba_func(func):
179154
if hasattr(func, "py_func"):
180155
return func.py_func
181156
return func
182157

183158
mocks = [
184159
mock.patch("numba.njit", njit_noop),
185-
mock.patch("numba.vectorize", vectorize_noop),
186160
mock.patch(
187161
"pytensor.link.numba.dispatch.basic.global_numba_func",
188162
py_global_numba_func,
@@ -191,9 +165,6 @@ def py_global_numba_func(func):
191165
"pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem
192166
),
193167
mock.patch("pytensor.link.numba.dispatch.basic.numba_njit", njit_noop),
194-
mock.patch(
195-
"pytensor.link.numba.dispatch.basic.numba_vectorize", vectorize_noop
196-
),
197168
mock.patch(
198169
"pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
199170
),

0 commit comments

Comments
 (0)