Skip to content

Commit 9c21d78

Browse files
committed
Implement iterator for TruncatedEig
1 parent 90c2c35 commit 9c21d78

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/lobpcg/lobpcg.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,27 @@ fn apply_constraints<A: Scalar + Lapack>(
8888
y: ArrayView2<A>,
8989
) {
9090
let gram_yv = y.t().dot(&v);
91+
dbg!(&gram_yv.shape());
9192

9293
let u = gram_yv
93-
.genrows()
94+
.gencolumns()
9495
.into_iter()
95-
.map(|x| fact_yy.solvec(&x).unwrap().to_vec())
96+
.map(|x| {
97+
dbg!(&x.shape());
98+
let res = fact_yy.solvec(&x).unwrap();
99+
100+
dbg!(&res);
101+
102+
res.to_vec()
103+
})
96104
.flatten()
97105
.collect::<Vec<A>>();
98106

99107
let rows = gram_yv.len_of(Axis(0));
100108
let u = Array2::from_shape_vec((rows, u.len() / rows), u).unwrap();
109+
dbg!(&u);
110+
dbg!(y.shape());
111+
dbg!(&v.shape());
101112

102113
v -= &(y.dot(&u));
103114
}
@@ -173,16 +184,14 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
173184
// factorize yy for later use
174185
let fact_yy = match y {
175186
Some(ref y) => {
176-
let fact_yy = y.t().dot(y).factorizec(UPLO::Upper).unwrap();
187+
let fact_yy = y.t().dot(y).factorizec(UPLO::Lower).unwrap();
177188

178189
apply_constraints(x.view_mut(), &fact_yy, y.view());
179190
Some(fact_yy)
180191
},
181192
None => None
182193
};
183194

184-
185-
186195
// orthonormalize the initial guess and calculate matrices AX and XAX
187196
let (x, _) = match orthonormalize(x) {
188197
Ok(x) => x,
@@ -362,7 +371,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
362371
iter -= 1;
363372
};
364373

365-
dbg!(&residual_norms);
374+
//dbg!(&residual_norms);
366375
let best_idx = residual_norms.iter().enumerate().min_by(
367376
|&(_, item1): &(usize, &Vec<A::Real>), &(_, item2): &(usize, &Vec<A::Real>)| {
368377
let norm1: A::Real = item1.iter().map(|x| (*x) * (*x)).sum();
@@ -452,7 +461,7 @@ mod tests {
452461
let n = a.len_of(Axis(0));
453462
let x: Array2<f64> = Array2::random((n, num), Uniform::new(0.0, 1.0));
454463

455-
let result = lobpcg(|y| a.dot(&y), x, None, None, 1e-10, n, order);
464+
let result = lobpcg(|y| a.dot(&y), x, None, None, 1e-10, 2 * n, order);
456465
match result {
457466
EigResult::Ok(vals, _, r_norms) | EigResult::Err(vals, _, r_norms, _) => {
458467
// check convergence

0 commit comments

Comments
 (0)