File tree Expand file tree Collapse file tree 1 file changed +11
-13
lines changed
pytensor/link/numba/dispatch Expand file tree Collapse file tree 1 file changed +11
-13
lines changed Original file line number Diff line number Diff line change 11import warnings
2- from copy import copy
32from functools import singledispatch
43
54import numba
65import numpy as np
76from numba import types
87from numba .core .errors import NumbaWarning , TypingError
98from numba .cpython .unsafe .tuple import tuple_setitem # noqa: F401
10- from numba .extending import overload
119
1210from pytensor import In , config
1311from pytensor .compile import NUMBA
@@ -296,21 +294,21 @@ def numba_funcify_FunctionGraph(
296294 )
297295
298296
299- def deepcopyop ( x ):
300- return copy ( x )
301-
297+ @ numba_funcify . register ( DeepCopyOp )
298+ def numba_funcify_DeepCopyOp ( op , node , ** kwargs ):
299+ if isinstance ( node . inputs [ 0 ]. type , TensorType ):
302300
303- @overload (deepcopyop )
304- def dispatch_deepcopyop (x ):
305- if isinstance (x , types .Array ):
306- return lambda x : np .copy (x )
301+ @numba_njit
302+ def deepcopy (x ):
303+ return np .copy (x )
307304
308- return lambda x : x
305+ else :
309306
307+ @numba_njit
308+ def deepcopy (x ):
309+ return x
310310
311- @numba_funcify .register (DeepCopyOp )
312- def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
313- return deepcopyop
311+ return deepcopy
314312
315313
316314@numba .extending .intrinsic
You can’t perform that action at this time.
0 commit comments