Skip to content

Commit 233fc8a

Browse files
committed
Restart the eigenvalue decomposition as well
1 parent 6c94817 commit 233fc8a

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

src/lobpcg/lobpcg.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
264264
NumCast::from(1.0e-8).unwrap()
265265
};
266266

267-
// if we once below the max_rnorm enable explicit gram flag
267+
// if we are once below the max_rnorm, enable explicit gram flag
268268
let max_norm = residual_norms.into_iter().fold(A::Real::neg_infinity(), A::Real::max);
269269

270270
explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
@@ -292,22 +292,19 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
292292
)
293293
};
294294

295-
let p_ap: Option<(_,_)> = ap.as_ref().and_then(|(p, ap)| {
296-
let active_p = ndarray_mask(p.view(), &activemask);
297-
let active_ap = ndarray_mask(ap.view(), &activemask);
295+
let p_ap = ap.as_ref()
296+
.and_then(|(p, ap)| {
297+
let active_p = ndarray_mask(p.view(), &activemask);
298+
let active_ap = ndarray_mask(ap.view(), &activemask);
298299

299-
if let Ok((active_p, p_r)) = orthonormalize(active_p) {
300-
if let Ok(active_ap) = p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap.reversed_axes()) {
301-
let active_ap = active_ap.reversed_axes();
302-
303-
Some((active_p, active_ap))
304-
} else {
305-
None
306-
}
307-
} else {
308-
None
309-
}
310-
});
300+
orthonormalize(active_p).map(|x| (active_ap, x)).ok()
301+
})
302+
.and_then(|(active_ap, (active_p, p_r))| {
303+
let active_ap = active_ap.reversed_axes();
304+
p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap)
305+
.map(|active_ap| (active_p, active_ap.reversed_axes()))
306+
.ok()
307+
});
311308

312309
// compute symmetric gram matrices
313310
let (gram_a, gram_b) = if let Some((active_p, active_ap)) = &p_ap {
@@ -356,7 +353,15 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
356353

357354
let (new_lambda, eig_vecs) = match sorted_eig(gram_a.view(), Some(gram_b.view()), size_x, &order) {
358355
Ok(x) => x,
359-
Err(err) => break Err(err),
356+
Err(err) => {
357+
// restart if the eigproblem decomposition failed
358+
if ap.is_some() {
359+
ap = None;
360+
continue;
361+
} else {
362+
break Err(err);
363+
}
364+
}
360365
};
361366
lambda = new_lambda;
362367

0 commit comments

Comments
 (0)