Skip to content

Commit 7e9b3f8

Browse files
committed
Don't use overload for deepcopy
1 parent 3fda09c commit 7e9b3f8

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import warnings
2-
from copy import copy
32
from functools import singledispatch
43

54
import numba
65
import numpy as np
76
from numba import types
87
from numba.core.errors import NumbaWarning, TypingError
98
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
10-
from numba.extending import overload
119

1210
from pytensor import In, config
1311
from pytensor.compile import NUMBA
@@ -296,21 +294,21 @@ def numba_funcify_FunctionGraph(
296294
)
297295

298296

299-
def deepcopyop(x):
300-
return copy(x)
301-
297+
@numba_funcify.register(DeepCopyOp)
298+
def numba_funcify_DeepCopyOp(op, node, **kwargs):
299+
if isinstance(node.inputs[0].type, TensorType):
302300

303-
@overload(deepcopyop)
304-
def dispatch_deepcopyop(x):
305-
if isinstance(x, types.Array):
306-
return lambda x: np.copy(x)
301+
@numba_njit
302+
def deepcopy(x):
303+
return np.copy(x)
307304

308-
return lambda x: x
305+
else:
309306

307+
@numba_njit
308+
def deepcopy(x):
309+
return x
310310

311-
@numba_funcify.register(DeepCopyOp)
312-
def numba_funcify_DeepCopyOp(op, node, **kwargs):
313-
return deepcopyop
311+
return deepcopy
314312

315313

316314
@numba.extending.intrinsic

0 commit comments

Comments
 (0)