@@ -8,7 +8,7 @@ use crate::{Lapack, Scalar};
88use ndarray:: prelude:: * ;
99use ndarray:: OwnedRepr ;
1010use ndarray:: ScalarOperand ;
11- use num_traits:: { NumCast , Float } ;
11+ use num_traits:: { Float , NumCast } ;
1212
1313/// Find largest or smallest eigenvalues
1414#[ derive( Debug , Clone ) ]
@@ -20,12 +20,12 @@ pub enum Order {
2020/// The result of the eigensolver
2121///
2222/// In the best case the eigensolver has converged with a result better than the given threshold,
23- /// then a `EigResult ::Ok` gives the eigenvalues, eigenvectors and norms. If an error ocurred
24- /// during the process, it is returned in `EigResult ::Err`, but the best result is still returned,
25- /// as it could be usable. If there is no result at all, then `EigResult ::NoResult` is returned.
23+ /// then a `LobpcgResult ::Ok` gives the eigenvalues, eigenvectors and norms. If an error ocurred
24+ /// during the process, it is returned in `LobpcgResult ::Err`, but the best result is still returned,
25+ /// as it could be usable. If there is no result at all, then `LobpcgResult ::NoResult` is returned.
2626/// This happens if the algorithm fails in an early stage, for example if the matrix `A` is not SPD
2727#[ derive( Debug ) ]
28- pub enum EigResult < A > {
28+ pub enum LobpcgResult < A > {
2929 Ok ( Array1 < A > , Array2 < A > , Vec < A > ) ,
3030 Err ( Array1 < A > , Array2 < A > , Vec < A > , LinalgError ) ,
3131 NoResult ( LinalgError ) ,
@@ -61,16 +61,18 @@ fn sorted_eig<A: Scalar + Lapack>(
6161fn ndarray_mask < A : Scalar > ( matrix : ArrayView2 < A > , mask : & [ bool ] ) -> Array2 < A > {
6262 assert_eq ! ( mask. len( ) , matrix. ncols( ) ) ;
6363
64- let indices = ( 0 ..mask. len ( ) ) . zip ( mask. into_iter ( ) )
65- . filter ( |( _, b) | * * b) . map ( |( a, _) | a)
64+ let indices = ( 0 ..mask. len ( ) )
65+ . zip ( mask. into_iter ( ) )
66+ . filter ( |( _, b) | * * b)
67+ . map ( |( a, _) | a)
6668 . collect :: < Vec < usize > > ( ) ;
6769
6870 matrix. select ( Axis ( 1 ) , & indices)
6971}
7072
7173/// Applies constraints ensuring that a matrix is orthogonal to it
7274///
73- /// This functions takes a matrix `v` and constraint matrix `y` and orthogonalize the `v` to `y`.
75+ /// This functions takes a matrix `v` and constraint- matrix `y` and orthogonalize `v` to `y`.
7476fn apply_constraints < A : Scalar + Lapack > (
7577 mut v : ArrayViewMut < A , Ix2 > ,
7678 cholesky_yy : & CholeskyFactorized < OwnedRepr < A > > ,
@@ -132,19 +134,23 @@ fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2
132134/// * `maxiter` - The maximal number of iterations
133135/// * `order` - Whether to solve for the largest or lowest eigenvalues
134136///
135- /// The function returns an `EigResult ` with the eigenvalue/eigenvector and achieved residual norm
137+ /// The function returns an `LobpcgResult ` with the eigenvalue/eigenvector and achieved residual norm
136138/// for it. All iterations are tracked and the optimal solution returned. In case of an error a
137- /// special variant `EigResult ::NotConverged` additionally carries the error. This can happen when
139+ /// special variant `LobpcgResult ::NotConverged` additionally carries the error. This can happen when
138140/// the precision of the matrix is too low (switch then from `f32` to `f64` for example).
139- pub fn lobpcg < A : Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default , F : Fn ( ArrayView2 < A > ) -> Array2 < A > , G : Fn ( ArrayViewMut2 < A > ) > (
141+ pub fn lobpcg <
142+ A : Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default ,
143+ F : Fn ( ArrayView2 < A > ) -> Array2 < A > ,
144+ G : Fn ( ArrayViewMut2 < A > ) ,
145+ > (
140146 a : F ,
141147 mut x : Array2 < A > ,
142148 m : G ,
143149 y : Option < Array2 < A > > ,
144150 tol : A :: Real ,
145151 maxiter : usize ,
146152 order : Order ,
147- ) -> EigResult < A > {
153+ ) -> LobpcgResult < A > {
148154 // the initital approximation should be maximal square
149155 // n is the dimensionality of the problem
150156 let ( n, size_x) = ( x. nrows ( ) , x. ncols ( ) ) ;
@@ -172,7 +178,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
172178 // orthonormalize the initial guess
173179 let ( x, _) = match orthonormalize ( x) {
174180 Ok ( x) => x,
175- Err ( err) => return EigResult :: NoResult ( err) ,
181+ Err ( err) => return LobpcgResult :: NoResult ( err) ,
176182 } ;
177183
178184 // calculate AX and XAX for Rayleigh quotient
@@ -182,7 +188,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
182188 // perform eigenvalue decomposition of XAX
183189 let ( mut lambda, eig_block) = match sorted_eig ( xax. view ( ) , None , size_x, & order) {
184190 Ok ( x) => x,
185- Err ( err) => return EigResult :: NoResult ( err) ,
191+ Err ( err) => return LobpcgResult :: NoResult ( err) ,
186192 } ;
187193
188194 // initiate approximation of the eigenvector
@@ -219,12 +225,20 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
219225
220226 // compare best result and update if we improved
221227 let sum_rnorm: A :: Real = residual_norms. iter ( ) . cloned ( ) . sum ( ) ;
222- if best_result. as_ref ( ) . map ( |x : & ( _ , _ , Vec < A :: Real > ) | x. 2 . iter ( ) . cloned ( ) . sum :: < A :: Real > ( ) > sum_rnorm) . unwrap_or ( true ) {
228+ if best_result
229+ . as_ref ( )
230+ . map ( |x : & ( _ , _ , Vec < A :: Real > ) | x. 2 . iter ( ) . cloned ( ) . sum :: < A :: Real > ( ) > sum_rnorm)
231+ . unwrap_or ( true )
232+ {
223233 best_result = Some ( ( lambda. clone ( ) , x. clone ( ) , residual_norms. clone ( ) ) ) ;
224234 }
225235
226236 // disable eigenvalues which are below the tolerance threshold
227- activemask = residual_norms. iter ( ) . zip ( activemask. iter ( ) ) . map ( |( x, a) | * x > tol && * a) . collect ( ) ;
237+ activemask = residual_norms
238+ . iter ( )
239+ . zip ( activemask. iter ( ) )
240+ . map ( |( x, a) | * x > tol && * a)
241+ . collect ( ) ;
228242
229243 // resize identity block if necessary
230244 let current_block_size = activemask. iter ( ) . filter ( |x| * * x) . count ( ) ;
@@ -279,23 +293,19 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
279293 rar = ( & rar + & rar. t ( ) ) / two;
280294 let xax = x. t ( ) . dot ( & ax) ;
281295
282- (
283- ( & xax + & xax. t ( ) ) / two,
284- x. t ( ) . dot ( & x) ,
285- r. t ( ) . dot ( & r) ,
286- x. t ( ) . dot ( & r)
287- )
296+ ( ( & xax + & xax. t ( ) ) / two, x. t ( ) . dot ( & x) , r. t ( ) . dot ( & r) , x. t ( ) . dot ( & r) )
288297 } else {
289298 (
290299 lambda_diag,
291300 ident0. clone ( ) ,
292301 ident. clone ( ) ,
293- Array2 :: zeros ( ( size_x, current_block_size) )
302+ Array2 :: zeros ( ( size_x, current_block_size) ) ,
294303 )
295304 } ;
296305
297306 // mask and orthonormalize P and AP
298- let p_ap = previous_p_ap. as_ref ( )
307+ let p_ap = previous_p_ap
308+ . as_ref ( )
299309 . and_then ( |( p, ap) | {
300310 let active_p = ndarray_mask ( p. view ( ) , & activemask) ;
301311 let active_ap = ndarray_mask ( ap. view ( ) , & activemask) ;
@@ -318,10 +328,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
318328 let xp = x. t ( ) . dot ( active_p) ;
319329 let rp = r. t ( ) . dot ( active_p) ;
320330 let ( pap, pp) = if explicit_gram_flag {
321- (
322- ( & pap + & pap. t ( ) ) / two,
323- active_p. t ( ) . dot ( active_p)
324- )
331+ ( ( & pap + & pap. t ( ) ) / two, active_p. t ( ) . dot ( active_p) )
325332 } else {
326333 ( pap, ident. clone ( ) )
327334 } ;
@@ -342,16 +349,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
342349 )
343350 } else {
344351 (
345- stack ! [
346- Axis ( 0 ) ,
347- stack![ Axis ( 1 ) , xax, xar] ,
348- stack![ Axis ( 1 ) , xar. t( ) , rar]
349- ] ,
350- stack ! [
351- Axis ( 0 ) ,
352- stack![ Axis ( 1 ) , xx, xr] ,
353- stack![ Axis ( 1 ) , xr. t( ) , rr]
354- ] ,
352+ stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xax, xar] , stack![ Axis ( 1 ) , xar. t( ) , rar] ] ,
353+ stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xx, xr] , stack![ Axis ( 1 ) , xr. t( ) , rr] ] ,
355354 )
356355 } ;
357356
@@ -363,16 +362,16 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
363362 if previous_p_ap. is_some ( ) {
364363 previous_p_ap = None ;
365364 continue ;
366- } else { // or break if restart is not possible
365+ } else {
366+ // or break if restart is not possible
367367 break Err ( err) ;
368368 }
369369 }
370370 } ;
371371 lambda = new_lambda;
372372
373373 // approximate eigenvector X and conjugate vectors P with solution of eigenproblem
374- let ( p, ap, tau) = if let Some ( ( active_p, active_ap) ) = p_ap
375- {
374+ let ( p, ap, tau) = if let Some ( ( active_p, active_ap) ) = p_ap {
376375 // tau are eigenvalues to basis of X
377376 let tau = eig_vecs. slice ( s ! [ ..size_x, ..] ) ;
378377 // alpha are eigenvalues to basis of R
@@ -414,8 +413,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
414413 dbg ! ( & residual_norms_history) ;
415414
416415 match final_norm {
417- Ok ( _) => EigResult :: Ok ( vals, vecs, rnorm) ,
418- Err ( err) => EigResult :: Err ( vals, vecs, rnorm, err)
416+ Ok ( _) => LobpcgResult :: Ok ( vals, vecs, rnorm) ,
417+ Err ( err) => LobpcgResult :: Err ( vals, vecs, rnorm, err) ,
419418 }
420419}
421420
@@ -425,7 +424,7 @@ mod tests {
425424 use super :: ndarray_mask;
426425 use super :: orthonormalize;
427426 use super :: sorted_eig;
428- use super :: EigResult ;
427+ use super :: LobpcgResult ;
429428 use super :: Order ;
430429 use crate :: close_l2;
431430 use crate :: qr:: * ;
@@ -486,7 +485,7 @@ mod tests {
486485 let result = lobpcg ( |y| a. dot ( & y) , x, |_| { } , None , 1e-5 , n * 2 , order) ;
487486 dbg ! ( & result) ;
488487 match result {
489- EigResult :: Ok ( vals, _, r_norms) | EigResult :: Err ( vals, _, r_norms, _) => {
488+ LobpcgResult :: Ok ( vals, _, r_norms) | LobpcgResult :: Err ( vals, _, r_norms, _) => {
490489 // check convergence
491490 for ( i, norm) in r_norms. into_iter ( ) . enumerate ( ) {
492491 if norm > 1e-5 {
@@ -501,7 +500,7 @@ mod tests {
501500 close_l2 ( & Array1 :: from ( ground_truth_eigvals. to_vec ( ) ) , & vals, num as f64 * 5e-4 )
502501 }
503502 }
504- EigResult :: NoResult ( err) => panic ! ( "Did not converge: {:?}" , err) ,
503+ LobpcgResult :: NoResult ( err) => panic ! ( "Did not converge: {:?}" , err) ,
505504 }
506505 }
507506
@@ -539,11 +538,15 @@ mod tests {
539538 let diag = arr1 ( & [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. , 10. ] ) ;
540539 let a = Array2 :: from_diag ( & diag) ;
541540 let x: Array2 < f64 > = Array2 :: random ( ( 10 , 1 ) , Uniform :: new ( 0.0 , 1.0 ) ) ;
542- let y: Array2 < f64 > = arr2 ( & [ [ 1.0 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] , [ 0. , 1.0 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ] ) . reversed_axes ( ) ;
541+ let y: Array2 < f64 > = arr2 ( & [
542+ [ 1.0 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ,
543+ [ 0. , 1.0 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ,
544+ ] )
545+ . reversed_axes ( ) ;
543546
544547 let result = lobpcg ( |y| a. dot ( & y) , x, |_| { } , Some ( y) , 1e-10 , 50 , Order :: Smallest ) ;
545548 match result {
546- EigResult :: Ok ( vals, vecs, r_norms) | EigResult :: Err ( vals, vecs, r_norms, _) => {
549+ LobpcgResult :: Ok ( vals, vecs, r_norms) | LobpcgResult :: Err ( vals, vecs, r_norms, _) => {
547550 // check convergence
548551 for ( i, norm) in r_norms. into_iter ( ) . enumerate ( ) {
549552 if norm > 0.01 {
@@ -561,7 +564,7 @@ mod tests {
561564 1e-5 ,
562565 ) ;
563566 }
564- EigResult :: NoResult ( err) => panic ! ( "Did not converge: {:?}" , err) ,
567+ LobpcgResult :: NoResult ( err) => panic ! ( "Did not converge: {:?}" , err) ,
565568 }
566569 }
567570}
0 commit comments