Skip to content

Commit 93f8b44

Browse files
committed
Remove unnecessary match blocks
1 parent 233fc8a commit 93f8b44

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/lobpcg/lobpcg.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
204204
let two: A = NumCast::from(2.0).unwrap();
205205

206206
let mut ap: Option<(Array2<A>, Array2<A>)> = None;
207-
let mut explicit_gram_flag = false;
207+
let mut explicit_gram_flag = true;
208208

209209
let final_norm = loop {
210210
// calculate residual
@@ -258,16 +258,16 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
258258

259259
let ar = a(r.view());
260260

261-
let max_rnorm = if A::epsilon() > NumCast::from(1e-8).unwrap() {
261+
// check whether `A` is of type `f32` or `f64`
262+
let max_rnorm_float = if A::epsilon() > NumCast::from(1e-8).unwrap() {
262263
NumCast::from(1.0).unwrap()
263264
} else {
264265
NumCast::from(1.0e-8).unwrap()
265266
};
266267

267268
// if we are once below the max_rnorm, enable explicit gram flag
268269
let max_norm = residual_norms.into_iter().fold(A::Real::neg_infinity(), A::Real::max);
269-
270-
explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
270+
explicit_gram_flag = max_norm <= max_rnorm_float || explicit_gram_flag;
271271

272272
// perform the Rayleigh Ritz procedure
273273
let xar = x.t().dot(&ar);
@@ -365,7 +365,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
365365
};
366366
lambda = new_lambda;
367367

368-
let (pp, app, eig_x) = if let (Some(_), Some((active_p, active_ap))) = (ap, p_ap)
368+
let (pp, app, eig_x) = if let Some((active_p, active_ap)) = p_ap
369369
{
370370
let eig_x = eig_vecs.slice(s![..size_x, ..]);
371371
let eig_r = eig_vecs.slice(s![size_x..size_x + current_block_size, ..]);
@@ -396,6 +396,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
396396
let (vals, vecs, rnorm) = best_result.unwrap();
397397
let rnorm = rnorm.into_iter().map(|x| Scalar::from_real(x)).collect();
398398

399+
dbg!(&residual_norms_history);
400+
399401
match final_norm {
400402
Ok(_) => EigResult::Ok(vals, vecs, rnorm),
401403
Err(err) => EigResult::Err(vals, vecs, rnorm, err)
@@ -466,7 +468,7 @@ mod tests {
466468
let n = a.len_of(Axis(0));
467469
let x: Array2<f64> = Array2::random((n, num), Uniform::new(0.0, 1.0));
468470

469-
let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-10, 2 * n, order);
471+
let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n, order);
470472
match result {
471473
EigResult::Ok(vals, _, r_norms) | EigResult::Err(vals, _, r_norms, _) => {
472474
// check convergence

0 commit comments

Comments
 (0)