@@ -6,8 +6,7 @@ use crate::error::{LinalgError, Result};
66use crate :: { cholesky:: * , close_l2, eigh:: * , norm:: * , triangular:: * } ;
77use crate :: { Lapack , Scalar } ;
88use ndarray:: prelude:: * ;
9- use ndarray:: OwnedRepr ;
10- use ndarray:: ScalarOperand ;
9+ use ndarray:: { OwnedRepr , ScalarOperand , Data } ;
1110use num_traits:: { Float , NumCast } ;
1211
1312/// Find largest or smallest eigenvalues
@@ -32,19 +31,19 @@ pub enum LobpcgResult<A> {
3231}
3332
3433/// Solve full eigenvalue problem, sort by `order` and truncate to `size`
35- fn sorted_eig < A : Scalar + Lapack > (
36- a : ArrayView2 < A > ,
37- b : Option < ArrayView2 < A > > ,
34+ fn sorted_eig < S : Data < Elem = A > , A : Scalar + Lapack > (
35+ a : ArrayBase < S , Ix2 > ,
36+ b : Option < ArrayBase < S , Ix2 > > ,
3837 size : usize ,
3938 order : & Order ,
4039) -> Result < ( Array1 < A > , Array2 < A > ) > {
40+ let n = a. len_of ( Axis ( 0 ) ) ;
41+
4142 let ( vals, vecs) = match b {
4243 Some ( b) => ( a, b) . eigh ( UPLO :: Upper ) . map ( |x| ( x. 0 , ( x. 1 ) . 0 ) ) ?,
4344 _ => a. eigh ( UPLO :: Upper ) ?,
4445 } ;
4546
46- let n = a. len_of ( Axis ( 0 ) ) ;
47-
4847 Ok ( match order {
4948 Order :: Largest => (
5049 vals. slice_move ( s ! [ n-size..; -1 ] ) . mapv ( |x| Scalar :: from_real ( x) ) ,
@@ -320,55 +319,60 @@ pub fn lobpcg<
320319 . ok ( )
321320 } ) ;
322321
323- // compute symmetric gram matrices
324- let ( gram_a, gram_b) = if let Some ( ( active_p, active_ap) ) = & p_ap {
325- let xap = x. t ( ) . dot ( active_ap) ;
326- let rap = r. t ( ) . dot ( active_ap) ;
327- let pap = active_p. t ( ) . dot ( active_ap) ;
328- let xp = x. t ( ) . dot ( active_p) ;
329- let rp = r. t ( ) . dot ( active_p) ;
330- let ( pap, pp) = if explicit_gram_flag {
331- ( ( & pap + & pap. t ( ) ) / two, active_p. t ( ) . dot ( active_p) )
332- } else {
333- ( pap, ident. clone ( ) )
334- } ;
322+ // compute symmetric gram matrices and calculate solution of eigenproblem
323+ //
324+ // first try to compute the eigenvalue decomposition of the span{R, X, P},
325+ // if this fails (or the algorithm was restarted), then just use span{R, X}
326+ let result = p_ap. as_ref ( )
327+ . ok_or ( LinalgError :: Lapack { return_code : 1 } )
328+ . and_then ( |( active_p, active_ap) | {
329+ let xap = x. t ( ) . dot ( active_ap) ;
330+ let rap = r. t ( ) . dot ( active_ap) ;
331+ let pap = active_p. t ( ) . dot ( active_ap) ;
332+ let xp = x. t ( ) . dot ( active_p) ;
333+ let rp = r. t ( ) . dot ( active_p) ;
334+ let ( pap, pp) = if explicit_gram_flag {
335+ ( ( & pap + & pap. t ( ) ) / two, active_p. t ( ) . dot ( active_p) )
336+ } else {
337+ ( pap, ident. clone ( ) )
338+ } ;
339+
340+ sorted_eig (
341+ stack ! [
342+ Axis ( 0 ) ,
343+ stack![ Axis ( 1 ) , xax, xar, xap] ,
344+ stack![ Axis ( 1 ) , xar. t( ) , rar, rap] ,
345+ stack![ Axis ( 1 ) , xap. t( ) , rap. t( ) , pap]
346+ ] ,
347+ Some ( stack ! [
348+ Axis ( 0 ) ,
349+ stack![ Axis ( 1 ) , xx, xr, xp] ,
350+ stack![ Axis ( 1 ) , xr. t( ) , rr, rp] ,
351+ stack![ Axis ( 1 ) , xp. t( ) , rp. t( ) , pp]
352+ ] ) ,
353+ size_x,
354+ & order
355+ )
356+ } )
357+ . or_else ( |_| {
358+ sorted_eig (
359+ stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xax, xar] , stack![ Axis ( 1 ) , xar. t( ) , rar] ] ,
360+ Some ( stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xx, xr] , stack![ Axis ( 1 ) , xr. t( ) , rr] ] ) ,
361+ size_x,
362+ & order
363+ )
364+ } ) ;
335365
336- (
337- stack ! [
338- Axis ( 0 ) ,
339- stack![ Axis ( 1 ) , xax, xar, xap] ,
340- stack![ Axis ( 1 ) , xar. t( ) , rar, rap] ,
341- stack![ Axis ( 1 ) , xap. t( ) , rap. t( ) , pap]
342- ] ,
343- stack ! [
344- Axis ( 0 ) ,
345- stack![ Axis ( 1 ) , xx, xr, xp] ,
346- stack![ Axis ( 1 ) , xr. t( ) , rr, rp] ,
347- stack![ Axis ( 1 ) , xp. t( ) , rp. t( ) , pp]
348- ] ,
349- )
350- } else {
351- (
352- stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xax, xar] , stack![ Axis ( 1 ) , xar. t( ) , rar] ] ,
353- stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , xx, xr] , stack![ Axis ( 1 ) , xr. t( ) , rr] ] ,
354- )
355- } ;
356366
357- // apply Rayleigh-Ritz method for (A - lambda) to calculate optimal expansion coefficients
358- let ( new_lambda, eig_vecs) = match sorted_eig ( gram_a. view ( ) , Some ( gram_b. view ( ) ) , size_x, & order) {
359- Ok ( x) => x,
360- Err ( err) => {
361- // restart if the eigproblem decomposition failed
362- if previous_p_ap. is_some ( ) {
363- previous_p_ap = None ;
364- continue ;
365- } else {
366- // or break if restart is not possible
367- break Err ( err) ;
368- }
369- }
370- } ;
371- lambda = new_lambda;
367+ // update eigenvalues and eigenvectors (lambda is also used in the next iteration)
368+ let eig_vecs;
369+ match result {
370+ Ok ( ( x, y) ) => {
371+ lambda = x;
372+ eig_vecs = y;
373+ } ,
374+ Err ( x) => break Err ( x)
375+ }
372376
373377 // approximate eigenvector X and conjugate vectors P with solution of eigenproblem
374378 let ( p, ap, tau) = if let Some ( ( active_p, active_ap) ) = p_ap {
0 commit comments