Commit 6c94817

Add restarting and improve performance with explicit gram flag
1 parent cea3c0e commit 6c94817

File tree

3 files changed: +94 -45 lines changed


src/lobpcg/eig.rs

Lines changed: 4 additions & 3 deletions
@@ -4,6 +4,7 @@ use crate::{Lapack, Scalar};
 ///
 use ndarray::prelude::*;
 use ndarray::stack;
+use ndarray::ScalarOperand;
 use ndarray_rand::rand_distr::Uniform;
 use ndarray_rand::RandomExt;
 use num_traits::{Float, NumCast};
@@ -23,7 +24,7 @@ pub struct TruncatedEig<A: Scalar> {
     maxiter: usize,
 }
 
-impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
+impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedEig<A> {
     pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
         TruncatedEig {
             precision: NumCast::from(1e-5).unwrap(),
@@ -88,7 +89,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
     }
 }
 
-impl<A: Float + Scalar + Lapack + PartialOrd + Default> IntoIterator for TruncatedEig<A> {
+impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIterator for TruncatedEig<A> {
     type Item = (Array1<A>, Array2<A>);
     type IntoIter = TruncatedEigIterator<A>;
 
@@ -111,7 +112,7 @@ pub struct TruncatedEigIterator<A: Scalar> {
     eig: TruncatedEig<A>,
 }
 
-impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEigIterator<A> {
+impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator for TruncatedEigIterator<A> {
     type Item = (Array1<A>, Array2<A>);
 
     fn next(&mut self) -> Option<Self::Item> {
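
The widened bounds above (adding Float and ScalarOperand) mirror what the reworked lobpcg core now needs from the element type: dividing whole gram blocks by a scalar and folding residual norms via neg_infinity/max. A minimal standalone sketch of the two operations these bounds unlock (function names are illustrative, not part of the crate):

use ndarray::{Array2, ScalarOperand};
use num_traits::{Float, NumCast};

// Symmetrize a gram block as (g + g^T) / 2; dividing an entire array by a
// scalar is what requires the ScalarOperand bound on the element type.
fn symmetrize<A: Float + ScalarOperand>(g: &Array2<A>) -> Array2<A> {
    let two: A = NumCast::from(2.0).unwrap();
    (g + &g.t()) / two
}

// Fold residual norms to their maximum; neg_infinity and max come from Float.
fn max_residual<A: Float>(norms: &[A]) -> A {
    norms.iter().cloned().fold(A::neg_infinity(), A::max)
}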

src/lobpcg/lobpcg.rs

Lines changed: 88 additions & 41 deletions
@@ -7,7 +7,8 @@ use crate::{cholesky::*, close_l2, eigh::*, norm::*, triangular::*};
 use crate::{Lapack, Scalar};
 use ndarray::prelude::*;
 use ndarray::OwnedRepr;
-use num_traits::NumCast;
+use ndarray::ScalarOperand;
+use num_traits::{NumCast, Float};
 
 /// Find largest or smallest eigenvalues
 #[derive(Debug, Clone)]
@@ -136,7 +137,7 @@ fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2
 /// for it. All iterations are tracked and the optimal solution returned. In case of an error a
 /// special variant `EigResult::NotConverged` additionally carries the error. This can happen when
 /// the precision of the matrix is too low (switch from `f32` to `f64` for example).
-pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) -> Array2<A>, G: Fn(ArrayViewMut2<A>)>(
+pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default, F: Fn(ArrayView2<A>) -> Array2<A>, G: Fn(ArrayViewMut2<A>)>(
     a: F,
     mut x: Array2<A>,
     m: G,
@@ -193,15 +194,17 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
     let mut ax = ax.dot(&eig_block);
 
     let mut activemask = vec![true; size_x];
-    let mut residual_norms = Vec::new();
+    let mut residual_norms_history = Vec::new();
     let mut best_result = None;
 
     let mut previous_block_size = size_x;
 
     let mut ident: Array2<A> = Array2::eye(size_x);
     let ident0: Array2<A> = Array2::eye(size_x);
+    let two: A = NumCast::from(2.0).unwrap();
 
     let mut ap: Option<(Array2<A>, Array2<A>)> = None;
+    let mut explicit_gram_flag = false;
 
     let final_norm = loop {
         // calculate residual
@@ -212,17 +215,17 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
         let r = &ax - &lambda_x;
 
         // calculate L2 norm of error for every eigenvalue
-        let tmp = r.gencolumns().into_iter().map(|x| x.norm()).collect::<Vec<A::Real>>();
-        residual_norms.push(tmp.clone());
+        let residual_norms = r.gencolumns().into_iter().map(|x| x.norm()).collect::<Vec<A::Real>>();
+        residual_norms_history.push(residual_norms.clone());
 
         // compare best result and update if we improved
-        let sum_rnorm: A::Real = tmp.iter().cloned().sum();
+        let sum_rnorm: A::Real = residual_norms.iter().cloned().sum();
         if best_result.as_ref().map(|x: &(_,_,Vec<A::Real>)| x.2.iter().cloned().sum::<A::Real>() > sum_rnorm).unwrap_or(true) {
-            best_result = Some((lambda.clone(), x.clone(), tmp.clone()));
+            best_result = Some((lambda.clone(), x.clone(), residual_norms.clone()));
         }
 
         // disable eigenvalues which are below the tolerance threshold
-        activemask = tmp.iter().zip(activemask.iter()).map(|(x, a)| *x > tol && *a).collect();
+        activemask = residual_norms.iter().zip(activemask.iter()).map(|(x, a)| *x > tol && *a).collect();
 
         // resize identity block if necessary
         let current_block_size = activemask.iter().filter(|x| **x).count();
@@ -234,17 +237,19 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
         // if we are below the threshold for all eigenvalue or exceeded the number of iteration,
         // abort
         if current_block_size == 0 || iter == 0 {
-            break Ok(tmp);
+            break Ok(residual_norms);
        }
 
        // select active eigenvalues, apply pre-conditioner, orthogonalize to Y and orthonormalize
        let mut active_block_r = ndarray_mask(r.view(), &activemask);
        // apply preconditioner
        m(active_block_r.view_mut());
-        // apply constraints
+        // apply constraints to the preconditioned residuals
        if let (Some(ref y), Some(ref fact_yy)) = (&y, &fact_yy) {
            apply_constraints(active_block_r.view_mut(), fact_yy, y.view());
        }
+        // orthogonalize the preconditioned residual to x
+        active_block_r -= &x.dot(&x.t().dot(&active_block_r));
 
        let (r, _) = match orthonormalize(active_block_r) {
            Ok(x) => x,
@@ -253,57 +258,99 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
 
        let ar = a(r.view());
 
+        let max_rnorm = if A::epsilon() > NumCast::from(1e-8).unwrap() {
+            NumCast::from(1.0).unwrap()
+        } else {
+            NumCast::from(1.0e-8).unwrap()
+        };
+
+        // if we once below the max_rnorm enable explicit gram flag
+        let max_norm = residual_norms.into_iter().fold(A::Real::neg_infinity(), A::Real::max);
+
+        explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
+
        // perform the Rayleigh Ritz procedure
-        let xaw = x.t().dot(&ar);
-        let waw = r.t().dot(&ar);
-        let xw = x.t().dot(&r);
+        let xar = x.t().dot(&ar);
+        let mut rar = r.t().dot(&ar);
 
-        // compute symmetric gram matrices
-        let (gram_a, gram_b, active_p, active_ap) = if let Some((ref p, ref ap)) = ap {
+        let (xax, xx, rr, xr) = if explicit_gram_flag {
+            rar = (&rar + &rar.t()) / two;
+            let xax = x.t().dot(&ax);
+
+            (
+                (&xax + &xax.t()) / two,
+                x.t().dot(&x),
+                r.t().dot(&r),
+                x.t().dot(&r)
+            )
+        } else {
+            (
+                lambda_diag,
+                ident0.clone(),
+                ident.clone(),
+                Array2::zeros((size_x, current_block_size))
+            )
+        };
+
+        let p_ap: Option<(_,_)> = ap.as_ref().and_then(|(p, ap)| {
            let active_p = ndarray_mask(p.view(), &activemask);
            let active_ap = ndarray_mask(ap.view(), &activemask);
 
-            let (active_p, p_r) = orthonormalize(active_p).unwrap();
+            if let Ok((active_p, p_r)) = orthonormalize(active_p) {
+                if let Ok(active_ap) = p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap.reversed_axes()) {
+                    let active_ap = active_ap.reversed_axes();
+
+                    Some((active_p, active_ap))
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        });
 
-            let active_ap = match p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap.reversed_axes()) {
-                Ok(x) => x,
-                Err(err) => break Err(err),
+        // compute symmetric gram matrices
+        let (gram_a, gram_b) = if let Some((active_p, active_ap)) = &p_ap {
+            let xap = x.t().dot(active_ap);
+            let rap = r.t().dot(active_ap);
+            let pap = active_p.t().dot(active_ap);
+            let xp = x.t().dot(active_p);
+            let rp = r.t().dot(active_p);
+            let (pap, pp) = if explicit_gram_flag {
+                (
+                    (&pap + &pap.t()) / two,
+                    active_p.t().dot(active_p)
+                )
+            } else {
+                (pap, ident.clone())
            };
 
-            let active_ap = active_ap.reversed_axes();
-
-            let xap = x.t().dot(&active_ap);
-            let wap = r.t().dot(&active_ap);
-            let pap = active_p.t().dot(&active_ap);
-            let xp = x.t().dot(&active_p);
-            let wp = r.t().dot(&active_p);
-
            (
                stack![
                    Axis(0),
-                    stack![Axis(1), lambda_diag, xaw, xap],
-                    stack![Axis(1), xaw.t(), waw, wap],
-                    stack![Axis(1), xap.t(), wap.t(), pap]
+                    stack![Axis(1), xax, xar, xap],
+                    stack![Axis(1), xar.t(), rar, rap],
+                    stack![Axis(1), xap.t(), rap.t(), pap]
                ],
                stack![
                    Axis(0),
-                    stack![Axis(1), ident0, xw, xp],
-                    stack![Axis(1), xw.t(), ident, wp],
-                    stack![Axis(1), xp.t(), wp.t(), ident]
+                    stack![Axis(1), xx, xr, xp],
+                    stack![Axis(1), xr.t(), rr, rp],
+                    stack![Axis(1), xp.t(), rp.t(), pp]
                ],
-                Some(active_p),
-                Some(active_ap),
            )
        } else {
            (
                stack![
                    Axis(0),
-                    stack![Axis(1), lambda_diag, xaw],
-                    stack![Axis(1), xaw.t(), waw]
+                    stack![Axis(1), xax, xar],
+                    stack![Axis(1), xar.t(), rar]
+                ],
+                stack![
+                    Axis(0),
+                    stack![Axis(1), xx, xr],
+                    stack![Axis(1), xr.t(), rr]
                ],
-                stack![Axis(0), stack![Axis(1), ident0, xw], stack![Axis(1), xw.t(), ident]],
-                None,
-                None,
            )
        };
 
@@ -313,7 +360,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
        };
        lambda = new_lambda;
 
-        let (pp, app, eig_x) = if let (Some(_), (Some(ref active_p), Some(ref active_ap))) = (ap, (active_p, active_ap))
+        let (pp, app, eig_x) = if let (Some(_), Some((active_p, active_ap))) = (ap, p_ap)
        {
            let eig_x = eig_vecs.slice(s![..size_x, ..]);
            let eig_r = eig_vecs.slice(s![size_x..size_x + current_block_size, ..]);
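
For context on the explicit_gram_flag branch above: while residuals are still large, the Rayleigh-Ritz gram matrices use the assumed identities X^T X = I and X^T A X = diag(lambda); once the largest residual norm falls below max_rnorm, the blocks are instead computed explicitly and re-symmetrized as (G + G^T) / 2. A standalone sketch of that explicit construction for the X and R blocks (the function name is illustrative, not the crate's API):

use ndarray::prelude::*;

// Explicit, symmetrized gram blocks, as used once explicit_gram_flag is set.
fn explicit_gram_blocks(
    a: &Array2<f64>,
    x: &Array2<f64>,
    r: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>) {
    let two = 2.0;

    let ax = a.dot(x);
    let ar = a.dot(r);

    let xax = x.t().dot(&ax);
    let xax = (&xax + &xax.t()) / two; // symmetrized x^T a x (replaces lambda_diag)
    let rar = r.t().dot(&ar);
    let rar = (&rar + &rar.t()) / two; // symmetrized r^T a r

    let xx = x.t().dot(x); // explicit x^T x (replaces the identity block)
    let rr = r.t().dot(r); // explicit r^T r
    let xr = x.t().dot(r); // explicit x^T r (replaces a zero block)

    (xax, rar, xx, rr, xr)
}

When conjugate directions are present, the diff applies the same (g + g.t()) / two symmetrization to pap and swaps the identity block pp for the explicit active_p.t().dot(active_p).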

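The restarting named in the commit message shows up in the new p_ap handling: where the old code called unwrap() or broke out of the loop with an error, a failed orthonormalization of the conjugate-direction block now just yields None, and the iteration proceeds with only the X and R blocks. A minimal sketch of that fallback pattern with illustrative names (not the crate's API):

/// Prepare the conjugate-direction pair (p, ap) for the Rayleigh-Ritz step.
/// Any failure yields None, which the caller treats as "restart this
/// iteration without the P block" rather than aborting the whole solver.
fn prepare_p<P, E>(
    p: P,
    ap: P,
    orthonormalize: impl FnOnce(P) -> Result<(P, P), E>,
    apply_inverse_factor: impl FnOnce(&P, P) -> Result<P, E>,
) -> Option<(P, P)> {
    let (p, p_r) = orthonormalize(p).ok()?;
    let ap = apply_inverse_factor(&p_r, ap).ok()?;
    Some((p, ap))
}
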
src/lobpcg/svd.rs

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@ use super::lobpcg::{lobpcg, EigResult, Order};
 use crate::error::Result;
 use crate::{Lapack, Scalar};
 use ndarray::prelude::*;
+use ndarray::ScalarOperand;
 use ndarray_rand::rand_distr::Uniform;
 use ndarray_rand::RandomExt;
 use num_traits::{Float, NumCast};
@@ -97,7 +98,7 @@ pub struct TruncatedSvd<A: Scalar> {
     maxiter: usize,
 }
 
-impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
+impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedSvd<A> {
     pub fn new(problem: Array2<A>, order: Order) -> TruncatedSvd<A> {
         TruncatedSvd {
             precision: NumCast::from(1e-5).unwrap(),
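
As in eig.rs, the TruncatedSvd bounds are widened so the shared lobpcg core can be reused for the truncated SVD. For context, a LOBPCG-based truncated SVD is typically obtained by eigen-decomposing a gram matrix of the input and taking square roots of the eigenvalues; a standalone sketch of that reduction (helper names are illustrative, not the crate's API):

use ndarray::prelude::*;

// Symmetric positive semi-definite gram matrix whose eigenvalues are the
// squared singular values of `a`.
fn gram_for_svd(a: &Array2<f64>) -> Array2<f64> {
    a.t().dot(a)
}

// Recover singular values from the gram eigenvalues, clamping tiny negative
// round-off before the square root.
fn singular_values(gram_eigenvalues: &Array1<f64>) -> Array1<f64> {
    gram_eigenvalues.mapv(|x| x.max(0.0).sqrt())
}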
