Skip to content

Commit 8188649

Browse files
committed
Implement IntoIterator for TruncatedEig
1 parent 9c21d78 commit 8188649

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

src/lobpcg/eig.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
///! Implements truncated eigenvalue decomposition
2+
///
3+
4+
use ndarray::prelude::*;
5+
use ndarray::stack;
6+
use ndarray_rand::rand_distr::Uniform;
7+
use ndarray_rand::RandomExt;
8+
use num_traits::{Float, NumCast};
9+
use crate::{Scalar, Lapack};
10+
use super::lobpcg::{lobpcg, EigResult, Order};
11+
12+
pub struct TruncatedEig<A: Scalar> {
13+
order: Order,
14+
problem: Array2<A>,
15+
pub constraints: Option<Array2<A>>,
16+
precision: A::Real,
17+
maxiter: usize
18+
}
19+
20+
impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
21+
pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
22+
TruncatedEig {
23+
precision: NumCast::from(1e-5).unwrap(),
24+
maxiter: problem.len_of(Axis(0)) * 2,
25+
constraints: None,
26+
order,
27+
problem
28+
}
29+
}
30+
31+
pub fn precision(mut self, precision: A::Real) -> Self {
32+
self.precision = precision;
33+
34+
self
35+
}
36+
37+
pub fn maxiter(mut self, maxiter: usize) -> Self {
38+
self.maxiter = maxiter;
39+
40+
self
41+
42+
}
43+
44+
pub fn constraints(mut self, constraints: Array2<A>) -> Self {
45+
self.constraints = Some(constraints);
46+
47+
self
48+
}
49+
50+
pub fn once(&self, num: usize) -> EigResult<A> {
51+
let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0))
52+
.mapv(|x| NumCast::from(x).unwrap());
53+
54+
lobpcg(|y| self.problem.dot(&y), x, None, self.constraints.clone(), self.precision, self.maxiter, self.order.clone())
55+
}
56+
}
57+
58+
impl<A: Float + Scalar + Lapack + PartialOrd + Default> IntoIterator for TruncatedEig<A> {
59+
type Item = (Array1<A>, Array2<A>);
60+
type IntoIter = TruncatedEigIterator<A>;
61+
62+
fn into_iter(self) -> TruncatedEigIterator<A>{
63+
TruncatedEigIterator {
64+
step_size: 1,
65+
eig: self
66+
}
67+
}
68+
}
69+
70+
pub struct TruncatedEigIterator<A: Scalar> {
71+
step_size: usize,
72+
eig: TruncatedEig<A>
73+
}
74+
75+
impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEigIterator<A> {
76+
type Item = (Array1<A>, Array2<A>);
77+
78+
fn next(&mut self) -> Option<Self::Item> {
79+
let res = self.eig.once(self.step_size);
80+
dbg!(&res);
81+
82+
match res {
83+
EigResult::Ok(vals, vecs, norms) | EigResult::Err(vals, vecs, norms, _) => {
84+
// abort if any eigenproblem did not converge
85+
for r_norm in norms {
86+
if r_norm > NumCast::from(0.1).unwrap() {
87+
return None;
88+
}
89+
}
90+
91+
let new_constraints = if let Some(ref constraints) = self.eig.constraints {
92+
let eigvecs_arr = constraints.gencolumns().into_iter()
93+
.chain(vecs.gencolumns().into_iter())
94+
.map(|x| x.insert_axis(Axis(1)))
95+
.collect::<Vec<_>>();
96+
97+
stack(Axis(1), &eigvecs_arr).unwrap()
98+
} else {
99+
vecs.clone()
100+
};
101+
102+
dbg!(&new_constraints);
103+
104+
self.eig.constraints = Some(new_constraints);
105+
106+
Some((vals, vecs))
107+
},
108+
EigResult::NoResult(_) => None
109+
}
110+
}
111+
}
112+
113+
#[cfg(test)]
114+
mod tests {
115+
use super::TruncatedEig;
116+
use super::Order;
117+
use ndarray::Array2;
118+
use ndarray_rand::rand_distr::Uniform;
119+
use ndarray_rand::RandomExt;
120+
121+
#[test]
122+
fn test_truncated_eig() {
123+
let a = Array2::random((50, 50), Uniform::new(0., 1.0));
124+
let a = a.t().dot(&a);
125+
126+
let teig = TruncatedEig::new(a, Order::Largest)
127+
.precision(1e-5)
128+
.maxiter(500);
129+
130+
let res = teig.into_iter().take(3).flat_map(|x| x.0.to_vec()).collect::<Vec<_>>();
131+
dbg!(&res);
132+
panic!("");
133+
}
134+
}

src/lobpcg/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod lobpcg;
2+
mod eig;
3+
4+
pub use lobpcg::{lobpcg, EigResult, Order};
5+
pub use eig::TruncatedEig;

0 commit comments

Comments
 (0)