66from numpy import ndarray
77from scipy import linalg
88
9+ from pytensor .link .numba .dispatch import numba_funcify
910from pytensor .link .numba .dispatch .basic import numba_njit
1011from pytensor .link .numba .dispatch .linalg ._LAPACK import (
1112 _LAPACK ,
2021 _solve_check ,
2122 _trans_char_to_int ,
2223)
24+ from pytensor .tensor ._linalg .solve .tridiagonal import (
25+ LUFactorTridiagonal ,
26+ SolveLUFactorTridiagonal ,
27+ )
2328
2429
2530@numba_njit
@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
3439
3540
def _gttrf(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    overwrite_dl: bool,
    overwrite_d: bool,
    overwrite_du: bool,
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
    """Placeholder for LU factorization of tridiagonal matrix.

    This pure-Python stub does no work and returns ``None``; it only fixes
    the call signature. The actual implementation is supplied separately by
    ``gttrf_impl`` (presumably registered as a numba overload of this
    function -- the registration is not visible here, confirm in the file).

    Parameters
    ----------
    dl, d, du : ndarray
        The three diagonals of the tridiagonal matrix (by LAPACK naming
        convention: sub-, main and super-diagonal -- TODO confirm).
    overwrite_dl, overwrite_d, overwrite_du : bool
        Whether the implementation may reuse the corresponding input buffer
        instead of working on a copy.

    Returns
    -------
    tuple
        Callers unpack the result as ``(dl, d, du, du2, ipiv, info)``.
    """
    # Intentionally empty: the annotated return type describes the compiled
    # implementation, not this stub, hence the type-checker suppression.
    return  # type: ignore
@@ -45,8 +55,12 @@ def gttrf_impl(
4555 dl : ndarray ,
4656 d : ndarray ,
4757 du : ndarray ,
58+ overwrite_dl : bool ,
59+ overwrite_d : bool ,
60+ overwrite_du : bool ,
4861) -> Callable [
49- [ndarray , ndarray , ndarray ], tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]
62+ [ndarray , ndarray , ndarray , bool , bool , bool ],
63+ tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ],
5064]:
5165 ensure_lapack ()
5266 _check_scipy_linalg_matrix (dl , "gttrf" )
@@ -60,12 +74,24 @@ def impl(
6074 dl : ndarray ,
6175 d : ndarray ,
6276 du : ndarray ,
77+ overwrite_dl : bool ,
78+ overwrite_d : bool ,
79+ overwrite_du : bool ,
6380 ) -> tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]:
6481 n = np .int32 (d .shape [- 1 ])
6582 ipiv = np .empty (n , dtype = np .int32 )
6683 du2 = np .empty (n - 2 , dtype = dtype )
6784 info = val_to_int_ptr (0 )
6885
86+ if not overwrite_dl or not dl .flags .f_contiguous :
87+ dl = dl .copy ()
88+
89+ if not overwrite_d or not d .flags .f_contiguous :
90+ d = d .copy ()
91+
92+ if not overwrite_du or not du .flags .f_contiguous :
93+ du = du .copy ()
94+
6995 numba_gttrf (
7096 val_to_int_ptr (n ),
7197 dl .view (w_type ).ctypes ,
@@ -133,10 +159,23 @@ def impl(
133159 nrhs = 1 if b .ndim == 1 else int (b .shape [- 1 ])
134160 info = val_to_int_ptr (0 )
135161
136- if overwrite_b and b .flags .f_contiguous :
137- b_copy = b
138- else :
139- b_copy = _copy_to_fortran_order_even_if_1d (b )
162+ if not overwrite_b or not b .flags .f_contiguous :
163+ b = _copy_to_fortran_order_even_if_1d (b )
164+
165+ if not dl .flags .f_contiguous :
166+ dl = dl .copy ()
167+
168+ if not d .flags .f_contiguous :
169+ d = d .copy ()
170+
171+ if not du .flags .f_contiguous :
172+ du = du .copy ()
173+
174+ if not du2 .flags .f_contiguous :
175+ du2 = du2 .copy ()
176+
177+ if not ipiv .flags .f_contiguous :
178+ ipiv = ipiv .copy ()
140179
141180 numba_gttrs (
142181 val_to_int_ptr (_trans_char_to_int (trans )),
@@ -147,12 +186,12 @@ def impl(
147186 du .view (w_type ).ctypes ,
148187 du2 .view (w_type ).ctypes ,
149188 ipiv .ctypes ,
150- b_copy .view (w_type ).ctypes ,
189+ b .view (w_type ).ctypes ,
151190 val_to_int_ptr (n ),
152191 info ,
153192 )
154193
155- return b_copy , int_ptr_to_val (info )
194+ return b , int_ptr_to_val (info )
156195
157196 return impl
158197
@@ -283,7 +322,9 @@ def impl(
283322
284323 anorm = tridiagonal_norm (du , d , dl )
285324
286- dl , d , du , du2 , IPIV , INFO = _gttrf (dl , d , du )
325+ dl , d , du , du2 , IPIV , INFO = _gttrf (
326+ dl , d , du , overwrite_dl = True , overwrite_d = True , overwrite_du = True
327+ )
287328 _solve_check (n , INFO )
288329
289330 X , INFO = _gttrs (
@@ -297,3 +338,48 @@ def impl(
297338 return X
298339
299340 return impl
341+
342+
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Build the numba implementation of a ``LUFactorTridiagonal`` op.

    Parameters
    ----------
    op : LUFactorTridiagonal
        The op being converted; its ``overwrite_dl``, ``overwrite_d`` and
        ``overwrite_du`` flags are read once here and captured by the
        returned function.
    node
        Unused by this implementation.
    **kwargs
        Unused by this implementation.

    Returns
    -------
    Callable
        A ``numba_njit``-wrapped function ``(dl, d, du)`` that returns
        ``(dl, d, du, du2, ipiv)`` as produced by ``_gttrf``.
    """
    # Read the flags outside the jitted function so that it closes over
    # plain values rather than over the op object.
    overwrite_dl = op.overwrite_dl
    overwrite_d = op.overwrite_d
    overwrite_du = op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # The sixth value returned by `_gttrf` (the status code that other
        # callers in this file name INFO) is deliberately discarded here.
        dl, d, du, du2, ipiv, _ = _gttrf(
            dl,
            d,
            du,
            overwrite_dl=overwrite_dl,
            overwrite_d=overwrite_d,
            overwrite_du=overwrite_du,
        )
        return dl, d, du, du2, ipiv

    return lu_factor_tridiagonal
362+
363+
@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Build the numba implementation of a ``SolveLUFactorTridiagonal`` op.

    Parameters
    ----------
    op : SolveLUFactorTridiagonal
        The op being converted; its ``overwrite_b`` and ``transposed``
        attributes are read once here and captured by the returned function.
    node
        Unused by this implementation.
    **kwargs
        Unused by this implementation.

    Returns
    -------
    Callable
        A ``numba_njit``-wrapped function ``(dl, d, du, du2, ipiv, b)`` that
        returns the first element of the pair returned by ``_gttrs``.
    """
    # Read the attributes outside the jitted function so that it closes over
    # plain values rather than over the op object.
    overwrite_b = op.overwrite_b
    transposed = op.transposed

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        # `_gttrs` returns (solution, info); the info status is deliberately
        # discarded. `transposed` is forwarded as the `trans` argument.
        x, _ = _gttrs(
            dl,
            d,
            du,
            du2,
            ipiv,
            b,
            overwrite_b=overwrite_b,
            trans=transposed,
        )
        return x

    return solve_lu_factor_tridiagonal
0 commit comments