@@ -204,7 +204,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
204204 let two: A = NumCast :: from ( 2.0 ) . unwrap ( ) ;
205205
206206 let mut ap: Option < ( Array2 < A > , Array2 < A > ) > = None ;
207- let mut explicit_gram_flag = false ;
207+ let mut explicit_gram_flag = true ;
208208
209209 let final_norm = loop {
210210 // calculate residual
@@ -258,16 +258,16 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
258258
259259 let ar = a ( r. view ( ) ) ;
260260
261- let max_rnorm = if A :: epsilon ( ) > NumCast :: from ( 1e-8 ) . unwrap ( ) {
261+ // check whether `A` is of type `f32` or `f64`
262+ let max_rnorm_float = if A :: epsilon ( ) > NumCast :: from ( 1e-8 ) . unwrap ( ) {
262263 NumCast :: from ( 1.0 ) . unwrap ( )
263264 } else {
264265 NumCast :: from ( 1.0e-8 ) . unwrap ( )
265266 } ;
266267
267268 // if we are once below the max_rnorm, enable explicit gram flag
268269 let max_norm = residual_norms. into_iter ( ) . fold ( A :: Real :: neg_infinity ( ) , A :: Real :: max) ;
269-
270- explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
270+ explicit_gram_flag = max_norm <= max_rnorm_float || explicit_gram_flag;
271271
272272 // perform the Rayleigh Ritz procedure
273273 let xar = x. t ( ) . dot ( & ar) ;
@@ -365,7 +365,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
365365 } ;
366366 lambda = new_lambda;
367367
368- let ( pp, app, eig_x) = if let ( Some ( _ ) , Some ( ( active_p, active_ap) ) ) = ( ap , p_ap)
368+ let ( pp, app, eig_x) = if let Some ( ( active_p, active_ap) ) = p_ap
369369 {
370370 let eig_x = eig_vecs. slice ( s ! [ ..size_x, ..] ) ;
371371 let eig_r = eig_vecs. slice ( s ! [ size_x..size_x + current_block_size, ..] ) ;
@@ -396,6 +396,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
396396 let ( vals, vecs, rnorm) = best_result. unwrap ( ) ;
397397 let rnorm = rnorm. into_iter ( ) . map ( |x| Scalar :: from_real ( x) ) . collect ( ) ;
398398
399+ dbg ! ( & residual_norms_history) ;
400+
399401 match final_norm {
400402 Ok ( _) => EigResult :: Ok ( vals, vecs, rnorm) ,
401403 Err ( err) => EigResult :: Err ( vals, vecs, rnorm, err)
@@ -466,7 +468,7 @@ mod tests {
466468 let n = a. len_of ( Axis ( 0 ) ) ;
467469 let x: Array2 < f64 > = Array2 :: random ( ( n, num) , Uniform :: new ( 0.0 , 1.0 ) ) ;
468470
469- let result = lobpcg ( |y| a. dot ( & y) , x, |_| { } , None , 1e-10 , 2 * n, order) ;
471+ let result = lobpcg ( |y| a. dot ( & y) , x, |_| { } , None , 1e-5 , n, order) ;
470472 match result {
471473 EigResult :: Ok ( vals, _, r_norms) | EigResult :: Err ( vals, _, r_norms, _) => {
472474 // check convergence
0 commit comments