1+ from functools import partial
2+
13import numpy as np
24import numpy .linalg
35import pytest
911from pytensor import tensor as at
1012from pytensor .compile import get_default_mode
1113from pytensor .configdefaults import config
14+ from pytensor .tensor import swapaxes
1215from pytensor .tensor .blockwise import Blockwise
1316from pytensor .tensor .elemwise import DimShuffle
14- from pytensor .tensor .math import _allclose
17+ from pytensor .tensor .math import _allclose , dot , matmul
1518from pytensor .tensor .nlinalg import Det , MatrixInverse , matrix_inverse
1619from pytensor .tensor .rewriting .linalg import inv_as_solve
1720from pytensor .tensor .slinalg import Cholesky , Solve , SolveTriangular , cholesky , solve
18- from pytensor .tensor .type import dmatrix , matrix , vector
21+ from pytensor .tensor .type import dmatrix , matrix , tensor , vector
1922from tests import unittest_tools as utt
2023from tests .test_rop import break_op
2124
@@ -137,33 +140,38 @@ def test_matrix_inverse_solve():
137140@pytest .mark .parametrize ("tag" , ("lower" , "upper" , None ))
138141@pytest .mark .parametrize ("cholesky_form" , ("lower" , "upper" ))
139142@pytest .mark .parametrize ("product" , ("lower" , "upper" , None ))
140- def test_cholesky_ldotlt (tag , cholesky_form , product ):
143+ @pytest .mark .parametrize ("op" , (dot , matmul ))
144+ def test_cholesky_ldotlt (tag , cholesky_form , product , op ):
141145 transform_removes_chol = tag is not None and product == tag
142146 transform_transposes = transform_removes_chol and cholesky_form != tag
143147
144- A = matrix ("L" )
148+ ndim = 2 if op == dot else 3
149+ A = tensor ("L" , shape = (None ,) * ndim )
145150 if tag :
146151 setattr (A .tag , tag + "_triangular" , True )
147152
148153 if product == "lower" :
149- M = A . dot ( A . T )
154+ M = op ( A , swapaxes ( A , - 1 , - 2 ) )
150155 elif product == "upper" :
151- M = A . T . dot ( A )
156+ M = op ( swapaxes ( A , - 1 , - 2 ), A )
152157 else :
153158 M = A
154159
155160 C = cholesky (M , lower = (cholesky_form == "lower" ))
156161 f = pytensor .function ([A ], C , mode = get_default_mode ().including ("cholesky_ldotlt" ))
157162
158163 no_cholesky_in_graph = not any (
159- isinstance (node .op , Cholesky ) for node in f .maker .fgraph .apply_nodes
164+ isinstance (node .op , Cholesky )
165+ or (isinstance (node .op , Blockwise ) and isinstance (node .op .core_op , Cholesky ))
166+ for node in f .maker .fgraph .apply_nodes
160167 )
161168
162169 assert no_cholesky_in_graph == transform_removes_chol
163170
164171 if transform_transposes :
172+ expected_order = (1 , 0 ) if ndim == 2 else (0 , 2 , 1 )
165173 assert any (
166- isinstance (node .op , DimShuffle ) and node .op .new_order == ( 1 , 0 )
174+ isinstance (node .op , DimShuffle ) and node .op .new_order == expected_order
167175 for node in f .maker .fgraph .apply_nodes
168176 )
169177
@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
183191 ]
184192 )
185193
194+ cholesky_vect_fn = np .vectorize (
195+ partial (scipy .linalg .cholesky , lower = (cholesky_form == "lower" )),
196+ signature = "(a, a)->(a, a)" ,
197+ )
198+
186199 for Av in Avs :
187200 if tag == "upper" :
188201 Av = Av .T
@@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
194207 else :
195208 Mv = Av
196209
197- assert np .all (
198- np .isclose (
199- scipy .linalg .cholesky (Mv , lower = (cholesky_form == "lower" )),
200- f (Av ),
201- )
210+ if ndim == 3 :
211+ Av = np .broadcast_to (Av , (5 , * Av .shape ))
212+ Mv = np .broadcast_to (Mv , (5 , * Mv .shape ))
213+
214+ np .testing .assert_allclose (
215+ cholesky_vect_fn (Mv ),
216+ f (Av ),
202217 )
203218
204219
0 commit comments