11use super :: * ;
22use crate :: { inner:: * , norm:: * } ;
3- use num_traits:: Zero ;
3+ use num_traits:: { One , Zero } ;
4+
5+ /// Calc a reflactor `w` from a vector `x`
6+ pub fn calc_reflector < A , S > ( x : & mut ArrayBase < S , Ix1 > )
7+ where
8+ A : Scalar + Lapack ,
9+ S : DataMut < Elem = A > ,
10+ {
11+ let norm = x. norm_l2 ( ) ;
12+ let alpha = x[ 0 ] . mul_real ( norm / x[ 0 ] . abs ( ) ) ;
13+ x[ 0 ] -= alpha;
14+ let inv_rev_norm = A :: Real :: one ( ) / x. norm_l2 ( ) ;
15+ azip ! ( mut a( x) in { * a = a. mul_real( inv_rev_norm) } ) ;
16+ }
17+
18+ /// Take a reflection using `w`
19+ pub fn reflect < A , S1 , S2 > ( w : & ArrayBase < S1 , Ix1 > , a : & mut ArrayBase < S2 , Ix1 > )
20+ where
21+ A : Scalar + Lapack ,
22+ S1 : Data < Elem = A > ,
23+ S2 : DataMut < Elem = A > ,
24+ {
25+ assert_eq ! ( w. len( ) , a. len( ) ) ;
26+ let n = a. len ( ) ;
27+ let c = A :: from ( 2.0 ) . unwrap ( ) * w. inner ( & a) ;
28+ for l in 0 ..n {
29+ a[ l] -= c * w[ l] ;
30+ }
31+ }
432
533/// Iterative orthogonalizer using Householder reflection
634#[ derive( Debug , Clone ) ]
@@ -27,13 +55,7 @@ impl<A: Scalar + Lapack> Householder<A> {
2755 {
2856 assert ! ( k < self . v. len( ) ) ;
2957 assert_eq ! ( a. len( ) , self . dim, "Input array size mismaches to the dimension" ) ;
30-
31- let w = self . v [ k] . slice ( s ! [ k..] ) ;
32- let mut a_slice = a. slice_mut ( s ! [ k..] ) ;
33- let c = A :: from ( 2.0 ) . unwrap ( ) * w. inner ( & a_slice) ;
34- for l in 0 ..self . dim - k {
35- a_slice[ l] -= c * w[ l] ;
36- }
58+ reflect ( & self . v [ k] . slice ( s ! [ k..] ) , & mut a. slice_mut ( s ! [ k..] ) ) ;
3759 }
3860
3961 /// Take forward reflection `P = P_l ... P_1`
@@ -110,14 +132,15 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
110132 for i in 0 ..k {
111133 coef[ i] = a[ i] ;
112134 }
135+ coef[ k] = A :: from_real ( alpha) ;
113136 if alpha < rtol {
114137 // linearly dependent
115- coef[ k] = A :: from_real ( alpha) ;
116138 return Err ( coef) ;
117139 }
118140
119- // Add reflector
120141 assert ! ( k < a. len( ) ) ; // this must hold because `alpha == 0` if k >= a.len()
142+
143+ // Add reflector
121144 let alpha = if a[ k] . abs ( ) > Zero :: zero ( ) {
122145 a[ k] . mul_real ( alpha / a[ k] . abs ( ) )
123146 } else {
@@ -158,3 +181,22 @@ where
158181 let h = Householder :: new ( dim) ;
159182 qr ( iter, h, rtol, strategy)
160183}
184+
185+ #[ cfg( test) ]
186+ mod tests {
187+ use super :: * ;
188+ use crate :: assert:: * ;
189+
190+ #[ test]
191+ fn check_reflector ( ) {
192+ let mut a = array ! [ c64:: new( 1.0 , 1.0 ) , c64:: new( 1.0 , 0.0 ) , c64:: new( 0.0 , 1.0 ) ] ;
193+ let mut w = a. clone ( ) ;
194+ calc_reflector ( & mut w) ;
195+ reflect ( & w, & mut a) ;
196+ close_l2 (
197+ & a,
198+ & array ! [ c64:: new( 2.0 . sqrt( ) , 2.0 . sqrt( ) ) , c64:: zero( ) , c64:: zero( ) ] ,
199+ 1e-9 ,
200+ ) ;
201+ }
202+ }
0 commit comments