|
8 | 8 | import numba |
9 | 9 | import numba.np.unsafe.ndarray as numba_ndarray |
10 | 10 | import numpy as np |
11 | | -import scipy |
12 | | -import scipy.special |
13 | 11 | from llvmlite import ir |
14 | 12 | from numba import types |
15 | 13 | from numba.core.errors import NumbaWarning, TypingError |
|
36 | 34 | from pytensor.tensor.blas import BatchedDot |
37 | 35 | from pytensor.tensor.math import Dot |
38 | 36 | from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape |
39 | | -from pytensor.tensor.slinalg import Solve |
40 | 37 | from pytensor.tensor.sort import ArgSortOp, SortOp |
41 | 38 | from pytensor.tensor.type import TensorType |
42 | 39 | from pytensor.tensor.type_other import MakeSlice, NoneConst |
@@ -626,51 +623,6 @@ def dot_with_cast(x, y): |
626 | 623 | return dot_with_cast |
627 | 624 |
|
628 | 625 |
|
629 | | -@numba_funcify.register(Solve) |
630 | | -def numba_funcify_Solve(op, node, **kwargs): |
631 | | - assume_a = op.assume_a |
632 | | - # check_finite = op.check_finite |
633 | | - |
634 | | - if assume_a != "gen": |
635 | | - lower = op.lower |
636 | | - |
637 | | - warnings.warn( |
638 | | - ( |
639 | | - "Numba will use object mode to allow the " |
640 | | - "`compute_uv` argument to `numpy.linalg.svd`." |
641 | | - ), |
642 | | - UserWarning, |
643 | | - ) |
644 | | - |
645 | | - ret_sig = get_numba_type(node.outputs[0].type) |
646 | | - |
647 | | - @numba_njit |
648 | | - def solve(a, b): |
649 | | - with numba.objmode(ret=ret_sig): |
650 | | - ret = scipy.linalg.solve_triangular( |
651 | | - a, |
652 | | - b, |
653 | | - lower=lower, |
654 | | - # check_finite=check_finite |
655 | | - ) |
656 | | - return ret |
657 | | - |
658 | | - else: |
659 | | - out_dtype = node.outputs[0].type.numpy_dtype |
660 | | - inputs_cast = int_to_float_fn(node.inputs, out_dtype) |
661 | | - |
662 | | - @numba_njit |
663 | | - def solve(a, b): |
664 | | - return np.linalg.solve( |
665 | | - inputs_cast(a), |
666 | | - inputs_cast(b), |
667 | | - # assume_a=assume_a, |
668 | | - # check_finite=check_finite, |
669 | | - ).astype(out_dtype) |
670 | | - |
671 | | - return solve |
672 | | - |
673 | | - |
674 | 626 | @numba_funcify.register(BatchedDot) |
675 | 627 | def numba_funcify_BatchedDot(op, node, **kwargs): |
676 | 628 | dtype = node.outputs[0].type.numpy_dtype |
|
0 commit comments