Skip to content

Commit a7ce0e3

Browse files
committed
Use select for ndarray_mask
1 parent 68d8ec7 commit a7ce0e3

File tree

1 file changed

+14
-41
lines changed

1 file changed

+14
-41
lines changed

src/lobpcg/lobpcg.rs

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,13 @@ fn sorted_eig<A: Scalar + Lapack>(
5858

5959
/// Masks a matrix with the given `matrix`
6060
fn ndarray_mask<A: Scalar>(matrix: ArrayView2<A>, mask: &[bool]) -> Array2<A> {
61-
let (rows, cols) = (matrix.nrows(), matrix.ncols());
61+
assert_eq!(mask.len(), matrix.ncols());
6262

63-
assert_eq!(mask.len(), cols);
63+
let indices = (0..mask.len()).zip(mask.into_iter())
64+
.filter(|(_,b)| **b).map(|(a,_)| a)
65+
.collect::<Vec<usize>>();
6466

65-
let n_positive = mask.iter().filter(|x| **x).count();
66-
67-
let matrix = matrix
68-
.gencolumns()
69-
.into_iter()
70-
.zip(mask.iter())
71-
.filter(|(_, x)| **x)
72-
.map(|(x, _)| x.to_vec())
73-
.flatten()
74-
.collect::<Vec<A>>();
75-
76-
Array2::from_shape_vec((n_positive, rows), matrix)
77-
.unwrap()
78-
.reversed_axes()
67+
matrix.select(Axis(1), &indices)
7968
}
8069

8170
/// Applies constraints ensuring that a matrix is orthogonal to it
@@ -193,19 +182,16 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
193182
let ax = a(x.view());
194183
let xax = x.t().dot(&ax);
195184

196-
// perform eigenvalue decomposition on XAX
185+
// perform eigenvalue decomposition of XAX as initialization
197186
let (mut lambda, eig_block) = match sorted_eig(xax.view(), None, size_x, &order) {
198187
Ok(x) => x,
199188
Err(err) => return EigResult::NoResult(err),
200189
};
201190

202-
//dbg!(&lambda, &eig_block);
203-
204191
// initiate X and AX with eigenvector
205192
let mut x = x.dot(&eig_block);
206193
let mut ax = ax.dot(&eig_block);
207194

208-
//dbg!(&X, &AX);
209195
let mut activemask = vec![true; size_x];
210196
let mut residual_norms = Vec::new();
211197
let mut results = vec![(lambda.clone(), x.clone())];
@@ -219,10 +205,11 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
219205

220206
let final_norm = loop {
221207
// calculate residual
222-
let lambda_tmp = lambda.clone().insert_axis(Axis(0));
223-
let tmp = &x * &lambda_tmp;
208+
let lambda_diag = Array2::from_diag(&lambda);
209+
let lambda_x = x.dot(&lambda_diag);
224210

225-
let r = &ax - &tmp;
211+
// calculate residual AX - lambdaX
212+
let r = &ax - &lambda_x;
226213

227214
// calculate L2 norm of error for every eigenvalue
228215
let tmp = r.gencolumns().into_iter().map(|x| x.norm()).collect::<Vec<A::Real>>();
@@ -248,7 +235,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
248235
let mut active_block_r = ndarray_mask(r.view(), &activemask);
249236
// apply preconditioner
250237
m(active_block_r.view_mut());
251-
238+
// apply constraints
252239
if let (Some(ref y), Some(ref fact_yy)) = (&y, &fact_yy) {
253240
apply_constraints(active_block_r.view_mut(), fact_yy, y.view());
254241
}
@@ -271,17 +258,14 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
271258
let active_ap = ndarray_mask(ap.view(), &activemask);
272259

273260
let (active_p, p_r) = orthonormalize(active_p).unwrap();
274-
//dbg!(&active_P, &P_R);
261+
275262
let active_ap = match p_r.solve_triangular(UPLO::Lower, Diag::NonUnit, &active_ap.reversed_axes()) {
276263
Ok(x) => x,
277264
Err(err) => break Err(err),
278265
};
279266

280267
let active_ap = active_ap.reversed_axes();
281268

282-
//dbg!(&active_AP);
283-
//dbg!(&R);
284-
285269
let xap = x.t().dot(&active_ap);
286270
let wap = r.t().dot(&active_ap);
287271
let pap = active_p.t().dot(&active_ap);
@@ -291,7 +275,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
291275
(
292276
stack![
293277
Axis(0),
294-
stack![Axis(1), Array2::from_diag(&lambda), xaw, xap],
278+
stack![Axis(1), lambda_diag, xaw, xap],
295279
stack![Axis(1), xaw.t(), waw, wap],
296280
stack![Axis(1), xap.t(), wap.t(), pap]
297281
],
@@ -308,7 +292,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
308292
(
309293
stack![
310294
Axis(0),
311-
stack![Axis(1), Array2::from_diag(&lambda), xaw],
295+
stack![Axis(1), lambda_diag, xaw],
312296
stack![Axis(1), xaw.t(), waw]
313297
],
314298
stack![Axis(0), stack![Axis(1), ident0, xw], stack![Axis(1), xw.t(), ident]],
@@ -323,25 +307,15 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
323307
};
324308
lambda = new_lambda;
325309

326-
//dbg!(&lambda, &eig_vecs);
327310
let (pp, app, eig_x) = if let (Some(_), (Some(ref active_p), Some(ref active_ap))) = (ap, (active_p, active_ap))
328311
{
329312
let eig_x = eig_vecs.slice(s![..size_x, ..]);
330313
let eig_r = eig_vecs.slice(s![size_x..size_x + current_block_size, ..]);
331314
let eig_p = eig_vecs.slice(s![size_x + current_block_size.., ..]);
332315

333-
//dbg!(&eig_X);
334-
//dbg!(&eig_R);
335-
//dbg!(&eig_P);
336-
337-
//dbg!(&R, &AR, &active_P, &active_AP);
338-
339316
let pp = r.dot(&eig_r) + active_p.dot(&eig_p);
340317
let app = ar.dot(&eig_r) + active_ap.dot(&eig_p);
341318

342-
//dbg!(&pp);
343-
//dbg!(&app);
344-
345319
(pp, app, eig_x)
346320
} else {
347321
let eig_x = eig_vecs.slice(s![..size_x, ..]);
@@ -363,7 +337,6 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
363337
iter -= 1;
364338
};
365339

366-
//dbg!(&residual_norms);
367340
let best_idx = residual_norms.iter().enumerate().min_by(
368341
|&(_, item1): &(usize, &Vec<A::Real>), &(_, item2): &(usize, &Vec<A::Real>)| {
369342
let norm1: A::Real = item1.iter().map(|x| (*x) * (*x)).sum();

0 commit comments

Comments
 (0)