Skip to content

Commit a1e19e7

Browse files
committed
use new decorator pattern, lapack trtri
1 parent 7c11584 commit a1e19e7

File tree

3 files changed

+72
-16
lines changed

3 files changed

+72
-16
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ExtractDiag,
2121
Eye,
2222
TensorVariable,
23-
Tri,
2423
concatenate,
2524
diag,
2625
diagonal,
@@ -53,6 +52,7 @@
5352
Solve,
5453
SolveBase,
5554
SolveTriangular,
55+
TriangularInv,
5656
_bilinear_solve_discrete_lyapunov,
5757
block_diag,
5858
cholesky,
@@ -1033,14 +1033,6 @@ def _find_triangular_op(var):
10331033
if is_lower or is_upper:
10341034
return (is_lower, is_upper)
10351035

1036-
if var.owner and isinstance(var.owner.op, Tri):
1037-
# The 'k' parameter of Tri determines the diagonal.
1038-
# k=0 is the main diagonal.
1039-
k = var.owner.op.k
1040-
if k == 0:
1041-
is_lower = var.owner.op.lower
1042-
return (is_lower, not is_lower)
1043-
10441036
if var.owner and isinstance(var.owner.op, Blockwise):
10451037
core_op = var.owner.op.core_op
10461038
if isinstance(core_op, Cholesky):
@@ -1051,16 +1043,13 @@ def _find_triangular_op(var):
10511043

10521044
@register_canonicalize
10531045
@register_stabilize
1054-
@node_rewriter([Blockwise])
1046+
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
10551047
def rewrite_inv_to_triangular_solve(fgraph, node):
10561048
"""
10571049
This rewrite takes advantage of the fact that the inverse of a triangular
10581050
matrix can be computed more efficiently than the inverse of a general
10591051
matrix by using a triangular solve instead of a general matrix inverse.
10601052
"""
1061-
core_op = node.op.core_op
1062-
if not isinstance(core_op, ALL_INVERSE_OPS):
1063-
return None
10641053

10651054
A = node.inputs[0]
10661055
triangular_info = _find_triangular_op(A)
@@ -1069,5 +1058,5 @@ def rewrite_inv_to_triangular_solve(fgraph, node):
10691058

10701059
is_lower, is_upper = triangular_info
10711060
if is_lower or is_upper:
1072-
I = pt.eye(A.shape[0], dtype=A.dtype)
1073-
return [solve_triangular(A, I, lower=is_lower)]
1061+
new_op = TriangularInv(lower=is_lower)
1062+
return [new_op(A)]

pytensor/tensor/slinalg.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.tensor import math as ptm
2121
from pytensor.tensor.basic import as_tensor_variable, diagonal
2222
from pytensor.tensor.blockwise import Blockwise
23-
from pytensor.tensor.nlinalg import kron, matrix_dot
23+
from pytensor.tensor.nlinalg import MatrixInverse, kron, matrix_dot
2424
from pytensor.tensor.shape import reshape
2525
from pytensor.tensor.type import matrix, tensor, vector
2626
from pytensor.tensor.variable import TensorVariable
@@ -1016,6 +1016,30 @@ def solve_triangular(
10161016
return cast(TensorVariable, ret)
10171017

10181018

1019+
class TriangularInv(MatrixInverse):
    """
    Computes the inverse of a triangular matrix with LAPACK ``trtri``.

    Parameters
    ----------
    lower : bool, default True
        Forwarded to ``trtri`` as its ``lower`` flag: ``True`` treats the
        input as lower triangular, ``False`` as upper triangular.

    Notes
    -----
    NOTE(review): LAPACK ``trtri`` is documented to reference only the
    selected triangle; what ends up in the opposite triangle of the output
    should be confirmed before callers rely on it.
    """

    __props__ = ("lower",)

    def __init__(self, lower=True):
        self.lower = lower

    def perform(self, node, inputs, outputs):
        """
        Invert the single input matrix and store the result in ``outputs``.

        Raises
        ------
        numpy.linalg.LinAlgError
            If ``trtri`` reports a positive ``info`` (singular matrix).
        ValueError
            If ``trtri`` reports a negative ``info`` (illegal argument).
        """
        (x,) = inputs
        (z,) = outputs
        # get_lapack_funcs selects the routine matching x's dtype
        # (s/d/c/z prefix), so the callable is not necessarily "dtrtri".
        (trtri,) = get_lapack_funcs(("trtri",), (x,))
        # This Op declares no destroy_map in this class, so the input buffer
        # must stay untouched: do not let LAPACK overwrite ``x`` in place.
        inv, info = trtri(x, lower=self.lower, overwrite_c=False)
        if info > 0:
            raise np.linalg.LinAlgError("Singular matrix")
        if info < 0:
            raise ValueError(
                f"illegal value in {-info}-th argument of internal trtri"
            )
        z[0] = inv
1041+
1042+
10191043
class Solve(SolveBase):
10201044
"""
10211045
Solve a system of linear equations.

tests/tensor/rewriting/test_linalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
MatrixInverse,
2424
MatrixPinv,
2525
SLogDet,
26+
inv,
2627
matrix_inverse,
2728
svd,
2829
)
@@ -34,6 +35,7 @@
3435
Solve,
3536
SolveBase,
3637
SolveTriangular,
38+
TriangularInv,
3739
cho_solve,
3840
cholesky,
3941
solve,
@@ -1060,3 +1062,44 @@ def solve_op_in_graph(graph):
10601062
np.testing.assert_allclose(
10611063
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
10621064
)
1065+
1066+
1067+
def test_triangular_inv_op():
    """Check ``TriangularInv`` against ``np.linalg.inv`` for both triangles."""
    # Seeded generator so that any failure is reproducible.
    rng = np.random.default_rng(1067)
    # Single precision cannot meet the double-precision tolerances.
    rtol, atol = (1e-5, 1e-7) if config.floatX == "float64" else (1e-4, 1e-5)

    x = matrix("x")
    for lower, take_triangle in ((True, np.tril), (False, np.triu)):
        f = function([x], Blockwise(TriangularInv(lower=lower))(x))

        # Shift the diagonal away from zero to keep the matrix
        # well-conditioned, and cast to floatX so the compiled function
        # accepts the value in float32 mode as well.
        a = (take_triangle(rng.random((5, 5))) + np.eye(5)).astype(config.floatX)
        a_inv = f(a)
        expected_inv = np.linalg.inv(a.astype("float64"))

        # Only the triangle selected by ``lower`` is compared.
        np.testing.assert_allclose(
            take_triangle(a_inv), take_triangle(expected_inv), rtol=rtol, atol=atol
        )
1087+
1088+
1089+
def test_inv_to_triangular_inv_rewrite():
    """``inv(cholesky(x))`` should be rewritten to use ``TriangularInv``."""
    x = matrix("x")

    x_chol = cholesky(x)
    y_chol = inv(x_chol)
    f_chol = function([x], y_chol)

    # The op may appear bare or wrapped as the core_op of another op.
    assert any(
        isinstance(node.op, TriangularInv)
        or (hasattr(node.op, "core_op") and isinstance(node.op.core_op, TriangularInv))
        for node in f_chol.maker.fgraph.apply_nodes
    )

    # Seeded generator so that any failure is reproducible.
    rng = np.random.default_rng(1089)
    a = rng.random((5, 5))
    # a @ a.T is positive semi-definite; adding the identity makes it
    # positive definite and well-conditioned. Cast to floatX so the compiled
    # function accepts the value in float32 mode as well.
    a = (np.dot(a, a.T) + np.eye(5)).astype(config.floatX)

    # Single precision cannot meet the double-precision tolerances.
    rtol, atol = (1e-5, 1e-7) if config.floatX == "float64" else (1e-4, 1e-5)
    expected = np.linalg.inv(np.linalg.cholesky(a.astype("float64")))
    np.testing.assert_allclose(f_chol(a), expected, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)