Skip to content

Commit f25a177

Browse files
committed
Only save best result for LOBPCG
1 parent a7ce0e3 commit f25a177

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

src/lobpcg/lobpcg.rs

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
194194

195195
let mut activemask = vec![true; size_x];
196196
let mut residual_norms = Vec::new();
197-
let mut results = vec![(lambda.clone(), x.clone())];
197+
let mut best_result = None;
198198

199199
let mut previous_block_size = size_x;
200200

@@ -215,6 +215,12 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
215215
let tmp = r.gencolumns().into_iter().map(|x| x.norm()).collect::<Vec<A::Real>>();
216216
residual_norms.push(tmp.clone());
217217

218+
// compare best result and update if we improved
219+
let sum_rnorm: A::Real = tmp.iter().cloned().sum();
220+
if best_result.as_ref().map(|x: &(_,_,Vec<A::Real>)| x.2.iter().cloned().sum::<A::Real>() > sum_rnorm).unwrap_or(true) {
221+
best_result = Some((lambda.clone(), x.clone(), tmp.clone()));
222+
}
223+
218224
// disable eigenvalues which are below the tolerance threshold
219225
activemask = tmp.iter().zip(activemask.iter()).map(|(x, a)| *x > tol && *a).collect();
220226

@@ -330,35 +336,17 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
330336
x = x.dot(&eig_x) + &pp;
331337
ax = ax.dot(&eig_x) + &app;
332338

333-
results.push((lambda.clone(), x.clone()));
334-
335339
ap = Some((pp, app));
336340

337341
iter -= 1;
338342
};
339343

340-
let best_idx = residual_norms.iter().enumerate().min_by(
341-
|&(_, item1): &(usize, &Vec<A::Real>), &(_, item2): &(usize, &Vec<A::Real>)| {
342-
let norm1: A::Real = item1.iter().map(|x| (*x) * (*x)).sum();
343-
let norm2: A::Real = item2.iter().map(|x| (*x) * (*x)).sum();
344-
norm1.partial_cmp(&norm2).unwrap()
345-
},
346-
);
344+
let (vals, vecs, rnorm) = best_result.unwrap();
345+
let rnorm = rnorm.into_iter().map(|x| Scalar::from_real(x)).collect();
347346

348-
match best_idx {
349-
Some((idx, norms)) => {
350-
let (vals, vecs) = results[idx].clone();
351-
let norms = norms.iter().map(|x| Scalar::from_real(*x)).collect();
352-
353-
match final_norm {
354-
Ok(_) => EigResult::Ok(vals, vecs, norms),
355-
Err(err) => EigResult::Err(vals, vecs, norms, err),
356-
}
357-
}
358-
None => match final_norm {
359-
Ok(_) => panic!("Not error available!"),
360-
Err(err) => EigResult::NoResult(err),
361-
},
347+
match final_norm {
348+
Ok(_) => EigResult::Ok(vals, vecs, rnorm),
349+
Err(err) => EigResult::Err(vals, vecs, rnorm, err)
362350
}
363351
}
364352

0 commit comments

Comments
 (0)