Skip to content

Commit 705d37a

Browse files
committed
Add truncated singular value decomposition module
1 parent 8188649 commit 705d37a

File tree

4 files changed

+186
-20
lines changed

4 files changed

+186
-20
lines changed

src/lobpcg/eig.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ use num_traits::{Float, NumCast};
99
use crate::{Scalar, Lapack};
1010
use super::lobpcg::{lobpcg, EigResult, Order};
1111

12+
/// Truncated eigenproblem solver
13+
///
14+
/// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
15+
/// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
16+
/// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
17+
/// pair.
1218
pub struct TruncatedEig<A: Scalar> {
1319
order: Order,
1420
problem: Array2<A>,
1521
pub constraints: Option<Array2<A>>,
22+
preconditioner: Option<Array2<A>>,
1623
precision: A::Real,
1724
maxiter: usize
1825
}
@@ -22,6 +29,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
2229
TruncatedEig {
2330
precision: NumCast::from(1e-5).unwrap(),
2431
maxiter: problem.len_of(Axis(0)) * 2,
32+
preconditioner: None,
2533
constraints: None,
2634
order,
2735
problem
@@ -41,17 +49,24 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
4149

4250
}
4351

44-
pub fn constraints(mut self, constraints: Array2<A>) -> Self {
52+
pub fn orthogonal_to(mut self, constraints: Array2<A>) -> Self {
4553
self.constraints = Some(constraints);
4654

4755
self
4856
}
4957

58+
pub fn precondition_with(mut self, preconditioner: Array2<A>) -> Self {
59+
self.preconditioner = Some(preconditioner);
60+
61+
self
62+
}
63+
64+
// calculate the eigenvalues once
5065
pub fn once(&self, num: usize) -> EigResult<A> {
5166
let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0))
5267
.mapv(|x| NumCast::from(x).unwrap());
5368

54-
lobpcg(|y| self.problem.dot(&y), x, None, self.constraints.clone(), self.precision, self.maxiter, self.order.clone())
69+
lobpcg(|y| self.problem.dot(&y), x, self.preconditioner.clone(), self.constraints.clone(), self.precision, self.maxiter, self.order.clone())
5570
}
5671
}
5772

@@ -67,6 +82,10 @@ impl<A: Float + Scalar + Lapack + PartialOrd + Default> IntoIterator for Truncat
6782
}
6883
}
6984

85+
/// Truncate eigenproblem iterator
86+
///
87+
/// This wraps a truncated eigenproblem and provides an iterator where each step yields a new
88+
/// eigenvalue/vector pair. Useful for generating pairs until a certain condition is met.
7089
pub struct TruncatedEigIterator<A: Scalar> {
7190
step_size: usize,
7291
eig: TruncatedEig<A>
@@ -77,7 +96,6 @@ impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEi
7796

7897
fn next(&mut self) -> Option<Self::Item> {
7998
let res = self.eig.once(self.step_size);
80-
dbg!(&res);
8199

82100
match res {
83101
EigResult::Ok(vals, vecs, norms) | EigResult::Err(vals, vecs, norms, _) => {
@@ -88,6 +106,7 @@ impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEi
88106
}
89107
}
90108

109+
// add the new eigenvector to the internal constrain matrix
91110
let new_constraints = if let Some(ref constraints) = self.eig.constraints {
92111
let eigvecs_arr = constraints.gencolumns().into_iter()
93112
.chain(vecs.gencolumns().into_iter())
@@ -99,8 +118,6 @@ impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEi
99118
vecs.clone()
100119
};
101120

102-
dbg!(&new_constraints);
103-
104121
self.eig.constraints = Some(new_constraints);
105122

106123
Some((vals, vecs))
@@ -114,21 +131,22 @@ impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEi
114131
mod tests {
115132
use super::TruncatedEig;
116133
use super::Order;
117-
use ndarray::Array2;
118-
use ndarray_rand::rand_distr::Uniform;
119-
use ndarray_rand::RandomExt;
134+
use ndarray::{arr1, Array2};
120135

121136
#[test]
122137
fn test_truncated_eig() {
123-
let a = Array2::random((50, 50), Uniform::new(0., 1.0));
124-
let a = a.t().dot(&a);
138+
let diag = arr1(&[
139+
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20.,
140+
]);
141+
let a = Array2::from_diag(&diag);
125142

126143
let teig = TruncatedEig::new(a, Order::Largest)
127144
.precision(1e-5)
128145
.maxiter(500);
129146

130147
let res = teig.into_iter().take(3).flat_map(|x| x.0.to_vec()).collect::<Vec<_>>();
131-
dbg!(&res);
132-
panic!("");
148+
let ground_truth = vec![20., 19., 18.];
149+
150+
assert!(ground_truth.into_iter().zip(res.into_iter()).map(|(x,y)| (x-y)*(x-y)).sum::<f64>() < 0.01);
133151
}
134152
}

src/lobpcg/lobpcg.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -511,14 +511,6 @@ mod tests {
511511
check_eigenvalues(&a, Order::Smallest, 5, &[-50.0, -48.0, -46.0, -44.0, -42.0]);
512512
}
513513

514-
#[test]
515-
fn test_eigsolver_convergence() {
516-
let tmp = Array2::random((50, 50), Uniform::new(0.0, 1.0));
517-
let a = tmp.t().dot(&tmp);
518-
519-
check_eigenvalues(&a, Order::Largest, 5, &[]);
520-
}
521-
522514
#[test]
523515
fn test_eigsolver_constrainted() {
524516
let diag = arr1(&[

src/lobpcg/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod lobpcg;
22
mod eig;
3+
mod svd;
34

45
pub use lobpcg::{lobpcg, EigResult, Order};
56
pub use eig::TruncatedEig;
7+
pub use svd::TruncatedSvd;

src/lobpcg/svd.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
///! Implements truncated singular value decomposition
2+
///
3+
4+
use std::ops::DivAssign;
5+
use ndarray::prelude::*;
6+
use ndarray::stack;
7+
use ndarray_rand::rand_distr::Uniform;
8+
use ndarray_rand::RandomExt;
9+
use num_traits::{Float, NumCast};
10+
use crate::{Scalar, Lapack};
11+
use super::lobpcg::{lobpcg, EigResult, Order};
12+
use crate::error::Result;
13+
14+
#[derive(Debug)]
15+
pub struct TruncatedSvdResult<A> {
16+
eigvals: Array1<A>,
17+
eigvecs: Array2<A>,
18+
problem: Array2<A>,
19+
ngm: bool
20+
}
21+
22+
impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
23+
fn singular_values_with_indices(&self) -> (Vec<A>, Vec<usize>) {
24+
let mut a = self.eigvals.iter()
25+
.map(|x| if *x < NumCast::from(1e-5).unwrap() { NumCast::from(0.0).unwrap() } else { *x })
26+
.map(|x| x.sqrt())
27+
.enumerate()
28+
.collect::<Vec<_>>();
29+
30+
a.sort_by(|(_,x), (_, y)| x.partial_cmp(&y).unwrap().reverse());
31+
32+
a.into_iter().map(|(a,b)| (b,a)).unzip()
33+
}
34+
35+
pub fn values(&self) -> Vec<A> {
36+
let (values, indices) = self.singular_values_with_indices();
37+
38+
values
39+
}
40+
41+
pub fn values_vecs(&self) -> (Array2<A>, Vec<A>, Array2<A>) {
42+
let (values, indices) = self.singular_values_with_indices();
43+
let n_values = values.iter().filter(|x| **x > NumCast::from(0.0).unwrap()).count();
44+
45+
if self.ngm {
46+
let vlarge = self.eigvecs.select(Axis(1), &indices);
47+
let mut ularge = self.problem.dot(&vlarge);
48+
49+
ularge.gencolumns_mut().into_iter()
50+
.zip(values.iter())
51+
.for_each(|(mut a,b)| a.mapv_inplace(|x| x / *b));
52+
53+
let vhlarge = vlarge.reversed_axes();
54+
55+
(vhlarge, values, ularge)
56+
} else {
57+
let ularge = self.eigvecs.select(Axis(1), &indices);
58+
59+
let mut vlarge = ularge.dot(&self.problem);
60+
vlarge.gencolumns_mut().into_iter()
61+
.zip(values.iter())
62+
.for_each(|(mut a,b)| a.mapv_inplace(|x| x / *b));
63+
let vhlarge = vlarge.reversed_axes();
64+
65+
(vhlarge, values, ularge)
66+
}
67+
}
68+
}
69+
70+
/// Truncated singular value decomposition
71+
///
72+
/// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
73+
/// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
74+
/// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
75+
/// pair.
76+
pub struct TruncatedSvd<A: Scalar> {
77+
order: Order,
78+
problem: Array2<A>,
79+
precision: A::Real,
80+
maxiter: usize
81+
}
82+
83+
impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
84+
pub fn new(problem: Array2<A>, order: Order) -> TruncatedSvd<A> {
85+
TruncatedSvd {
86+
precision: NumCast::from(1e-5).unwrap(),
87+
maxiter: problem.len_of(Axis(0)) * 2,
88+
order,
89+
problem
90+
}
91+
}
92+
93+
pub fn precision(mut self, precision: A::Real) -> Self {
94+
self.precision = precision;
95+
96+
self
97+
}
98+
99+
pub fn maxiter(mut self, maxiter: usize) -> Self {
100+
self.maxiter = maxiter;
101+
102+
self
103+
104+
}
105+
106+
// calculate the eigenvalues once
107+
pub fn once(&self, num: usize) -> Result<TruncatedSvdResult<A>> {
108+
let (n,m) = (self.problem.rows(), self.problem.ncols());
109+
110+
let x = Array2::random((usize::min(n,m), num), Uniform::new(0.0, 1.0))
111+
.mapv(|x| NumCast::from(x).unwrap());
112+
113+
let res = if n > m {
114+
lobpcg(|y| self.problem.t().dot(&self.problem.dot(&y)), x, None, None, self.precision, self.maxiter, self.order.clone())
115+
} else {
116+
lobpcg(|y| self.problem.dot(&self.problem.t().dot(&y)), x, None, None, self.precision, self.maxiter, self.order.clone())
117+
};
118+
119+
match res {
120+
EigResult::Ok(vals, vecs, _) | EigResult::Err(vals, vecs, _, _) => {
121+
Ok(TruncatedSvdResult {
122+
problem: self.problem.clone(),
123+
eigvals: vals,
124+
eigvecs: vecs,
125+
ngm: n > m
126+
})
127+
},
128+
_ => panic!("")
129+
}
130+
}
131+
}
132+
133+
#[cfg(test)]
134+
mod tests {
135+
use super::TruncatedSvd;
136+
use super::Order;
137+
use ndarray::{arr1, Array2};
138+
139+
#[test]
140+
fn test_truncated_svd() {
141+
let diag = arr1(&[
142+
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20.,
143+
]);
144+
let a = Array2::from_diag(&diag);
145+
146+
let res = TruncatedSvd::new(a, Order::Largest)
147+
.precision(1e-5)
148+
.maxiter(500)
149+
.once(3)
150+
.unwrap();
151+
152+
dbg!(&res.values());
153+
}
154+
}

0 commit comments

Comments
 (0)