11//! Implement linear solver using LU decomposition
22//! for tridiagonal matrix
33
4+ mod matrix;
5+
6+ pub use matrix:: * ;
7+
48use crate :: { error:: * , layout:: * , * } ;
59use cauchy:: * ;
610use num_traits:: Zero ;
7- use std:: ops:: { Index , IndexMut } ;
8-
9- /// Represents a tridiagonal matrix as 3 one-dimensional vectors.
10- ///
11- /// ```text
12- /// [d0, u1, 0, ..., 0,
13- /// l1, d1, u2, ...,
14- /// 0, l2, d2,
15- /// ... ..., u{n-1},
16- /// 0, ..., l{n-1}, d{n-1},]
17- /// ```
18- #[ derive( Clone , PartialEq , Eq ) ]
19- pub struct Tridiagonal < A : Scalar > {
20- /// layout of raw matrix
21- pub l : MatrixLayout ,
22- /// (n-1) sub-diagonal elements of matrix.
23- pub dl : Vec < A > ,
24- /// (n) diagonal elements of matrix.
25- pub d : Vec < A > ,
26- /// (n-1) super-diagonal elements of matrix.
27- pub du : Vec < A > ,
28- }
29-
30- impl < A : Scalar > Tridiagonal < A > {
31- fn opnorm_one ( & self ) -> A :: Real {
32- let mut col_sum: Vec < A :: Real > = self . d . iter ( ) . map ( |val| val. abs ( ) ) . collect ( ) ;
33- for i in 0 ..col_sum. len ( ) {
34- if i < self . dl . len ( ) {
35- col_sum[ i] += self . dl [ i] . abs ( ) ;
36- }
37- if i > 0 {
38- col_sum[ i] += self . du [ i - 1 ] . abs ( ) ;
39- }
40- }
41- let mut max = A :: Real :: zero ( ) ;
42- for & val in & col_sum {
43- if max < val {
44- max = val;
45- }
46- }
47- max
48- }
49- }
5011
5112/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
5213#[ derive( Clone , PartialEq ) ]
@@ -65,66 +26,6 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
6526 a_opnorm_one : A :: Real ,
6627}
6728
68- impl < A : Scalar > Index < ( i32 , i32 ) > for Tridiagonal < A > {
69- type Output = A ;
70- #[ inline]
71- fn index ( & self , ( row, col) : ( i32 , i32 ) ) -> & A {
72- let ( n, _) = self . l . size ( ) ;
73- assert ! (
74- std:: cmp:: max( row, col) < n,
75- "ndarray: index {:?} is out of bounds for array of shape {}" ,
76- [ row, col] ,
77- n
78- ) ;
79- match row - col {
80- 0 => & self . d [ row as usize ] ,
81- 1 => & self . dl [ col as usize ] ,
82- -1 => & self . du [ row as usize ] ,
83- _ => panic ! (
84- "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
85- [ row, col]
86- ) ,
87- }
88- }
89- }
90-
91- impl < A : Scalar > Index < [ i32 ; 2 ] > for Tridiagonal < A > {
92- type Output = A ;
93- #[ inline]
94- fn index ( & self , [ row, col] : [ i32 ; 2 ] ) -> & A {
95- & self [ ( row, col) ]
96- }
97- }
98-
99- impl < A : Scalar > IndexMut < ( i32 , i32 ) > for Tridiagonal < A > {
100- #[ inline]
101- fn index_mut ( & mut self , ( row, col) : ( i32 , i32 ) ) -> & mut A {
102- let ( n, _) = self . l . size ( ) ;
103- assert ! (
104- std:: cmp:: max( row, col) < n,
105- "ndarray: index {:?} is out of bounds for array of shape {}" ,
106- [ row, col] ,
107- n
108- ) ;
109- match row - col {
110- 0 => & mut self . d [ row as usize ] ,
111- 1 => & mut self . dl [ col as usize ] ,
112- -1 => & mut self . du [ row as usize ] ,
113- _ => panic ! (
114- "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element" ,
115- [ row, col]
116- ) ,
117- }
118- }
119- }
120-
121- impl < A : Scalar > IndexMut < [ i32 ; 2 ] > for Tridiagonal < A > {
122- #[ inline]
123- fn index_mut ( & mut self , [ row, col] : [ i32 ; 2 ] ) -> & mut A {
124- & mut self [ ( row, col) ]
125- }
126- }
127-
12829/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
12930pub trait Tridiagonal_ : Scalar + Sized {
13031 /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
0 commit comments