Skip to content

Commit 3198d94

Browse files
committed
Disable numba cache for cython functions
1 parent 3ed908b commit 3198d94

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,14 @@
4848

4949
def numba_njit(*args, **kwargs):
5050

51+
kwargs = kwargs.copy()
52+
if "cache" not in kwargs:
53+
kwargs["cache"] = config.numba__cache
54+
5155
if len(args) > 0 and callable(args[0]):
52-
return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
56+
return numba.njit(*args[1:], **kwargs)(args[0])
5357

54-
return numba.njit(*args, cache=config.numba__cache, **kwargs)
58+
return numba.njit(*args, **kwargs)
5559

5660

5761
def numba_vectorize(*args, **kwargs):

pytensor/link/numba/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
144144
signature = create_numba_signature(node, force_scalar=True)
145145

146146
return numba_basic.numba_njit(
147-
signature, inline="always", fastmath=config.numba__fastmath
147+
signature, inline="always", fastmath=config.numba__fastmath, cache=False,
148148
)(scalar_op_fn)
149149

150150

0 commit comments

Comments
 (0)