Skip to content

Commit 37da5f6

Browse files
committed
Remove duplicated Solve dispatch
1 parent a4a2ac0 commit 37da5f6

File tree

1 file changed

+0
-48
lines changed

1 file changed

+0
-48
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import numba
99
import numba.np.unsafe.ndarray as numba_ndarray
1010
import numpy as np
11-
import scipy
12-
import scipy.special
1311
from llvmlite import ir
1412
from numba import types
1513
from numba.core.errors import NumbaWarning, TypingError
@@ -36,7 +34,6 @@
3634
from pytensor.tensor.blas import BatchedDot
3735
from pytensor.tensor.math import Dot
3836
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
39-
from pytensor.tensor.slinalg import Solve
4037
from pytensor.tensor.sort import ArgSortOp, SortOp
4138
from pytensor.tensor.type import TensorType
4239
from pytensor.tensor.type_other import MakeSlice, NoneConst
@@ -626,51 +623,6 @@ def dot_with_cast(x, y):
626623
return dot_with_cast
627624

628625

629-
@numba_funcify.register(Solve)
630-
def numba_funcify_Solve(op, node, **kwargs):
631-
assume_a = op.assume_a
632-
# check_finite = op.check_finite
633-
634-
if assume_a != "gen":
635-
lower = op.lower
636-
637-
warnings.warn(
638-
(
639-
"Numba will use object mode to allow the "
640-
"`compute_uv` argument to `numpy.linalg.svd`."
641-
),
642-
UserWarning,
643-
)
644-
645-
ret_sig = get_numba_type(node.outputs[0].type)
646-
647-
@numba_njit
648-
def solve(a, b):
649-
with numba.objmode(ret=ret_sig):
650-
ret = scipy.linalg.solve_triangular(
651-
a,
652-
b,
653-
lower=lower,
654-
# check_finite=check_finite
655-
)
656-
return ret
657-
658-
else:
659-
out_dtype = node.outputs[0].type.numpy_dtype
660-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
661-
662-
@numba_njit
663-
def solve(a, b):
664-
return np.linalg.solve(
665-
inputs_cast(a),
666-
inputs_cast(b),
667-
# assume_a=assume_a,
668-
# check_finite=check_finite,
669-
).astype(out_dtype)
670-
671-
return solve
672-
673-
674626
@numba_funcify.register(BatchedDot)
675627
def numba_funcify_BatchedDot(op, node, **kwargs):
676628
dtype = node.outputs[0].type.numpy_dtype

0 commit comments

Comments
 (0)