@@ -264,7 +264,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
264264 NumCast :: from ( 1.0e-8 ) . unwrap ( )
265265 } ;
266266
267- // if we once below the max_rnorm enable explicit gram flag
267+ // if we are once below the max_rnorm, enable explicit gram flag
268268 let max_norm = residual_norms. into_iter ( ) . fold ( A :: Real :: neg_infinity ( ) , A :: Real :: max) ;
269269
270270 explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
@@ -292,22 +292,19 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
292292 )
293293 } ;
294294
295- let p_ap: Option < ( _ , _ ) > = ap. as_ref ( ) . and_then ( |( p, ap) | {
296- let active_p = ndarray_mask ( p. view ( ) , & activemask) ;
297- let active_ap = ndarray_mask ( ap. view ( ) , & activemask) ;
295+ let p_ap = ap. as_ref ( )
296+ . and_then ( |( p, ap) | {
297+ let active_p = ndarray_mask ( p. view ( ) , & activemask) ;
298+ let active_ap = ndarray_mask ( ap. view ( ) , & activemask) ;
298299
299- if let Ok ( ( active_p, p_r) ) = orthonormalize ( active_p) {
300- if let Ok ( active_ap) = p_r. solve_triangular ( UPLO :: Lower , Diag :: NonUnit , & active_ap. reversed_axes ( ) ) {
301- let active_ap = active_ap. reversed_axes ( ) ;
302-
303- Some ( ( active_p, active_ap) )
304- } else {
305- None
306- }
307- } else {
308- None
309- }
310- } ) ;
300+ orthonormalize ( active_p) . map ( |x| ( active_ap, x) ) . ok ( )
301+ } )
302+ . and_then ( |( active_ap, ( active_p, p_r) ) | {
303+ let active_ap = active_ap. reversed_axes ( ) ;
304+ p_r. solve_triangular ( UPLO :: Lower , Diag :: NonUnit , & active_ap)
305+ . map ( |active_ap| ( active_p, active_ap. reversed_axes ( ) ) )
306+ . ok ( )
307+ } ) ;
311308
312309 // compute symmetric gram matrices
313310 let ( gram_a, gram_b) = if let Some ( ( active_p, active_ap) ) = & p_ap {
@@ -356,7 +353,15 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
356353
357354 let ( new_lambda, eig_vecs) = match sorted_eig ( gram_a. view ( ) , Some ( gram_b. view ( ) ) , size_x, & order) {
358355 Ok ( x) => x,
359- Err ( err) => break Err ( err) ,
356+ Err ( err) => {
357+ // restart if the eigproblem decomposition failed
358+ if ap. is_some ( ) {
359+ ap = None ;
360+ continue ;
361+ } else {
362+ break Err ( err) ;
363+ }
364+ }
360365 } ;
361366 lambda = new_lambda;
362367
0 commit comments