File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -122,6 +122,38 @@ def test_jax_solve():
122122 )
123123
124124
125+ def test_jax_tridiagonal_solve ():
126+ N = 10
127+ A = pt .matrix ("A" , shape = (N , N ))
128+ b = pt .vector ("b" , shape = (N ,))
129+
130+ out = pt .linalg .solve (A , b , assume_a = "tridiagonal" )
131+
132+ A_val = np .eye (N )
133+ for i in range (N - 1 ):
134+ A_val [i , i + 1 ] = np .random .randn ()
135+ A_val [i + 1 , i ] = np .random .randn ()
136+
137+ b_val = np .random .randn (N )
138+
139+ compare_jax_and_py (
140+ [A , b ],
141+ [out ],
142+ [A_val , b_val ],
143+ )
144+
145+ b_ = pt .matrix ("b" , shape = (N , 2 ))
146+
147+ out = pt .linalg .solve (A , b_ , assume_a = "tridiagonal" )
148+ b_val = np .random .randn (N , 2 )
149+
150+ compare_jax_and_py (
151+ [A , b_ ],
152+ [out ],
153+ [A_val , b_val ],
154+ )
155+
156+
125157def test_jax_SolveTriangular ():
126158 rng = np .random .default_rng (utt .fetch_seed ())
127159
You can’t perform that action at this time.
0 commit comments