Skip to content

Commit 3334034

Browse files
committed
Add test for random matrix reconstruction with SVD
1 parent b0e1a40 commit 3334034

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

src/lobpcg/svd.rs

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use crate::{Scalar, Lapack};
1010
use super::lobpcg::{lobpcg, EigResult, Order};
1111
use crate::error::Result;
1212

13+
/// The result of an eigenvalue decomposition for SVD
14+
///
15+
/// Provides methods for either calculating just the singular values with reduced cost or the
16+
/// vectors as well.
1317
#[derive(Debug)]
1418
pub struct TruncatedSvdResult<A> {
1519
eigvals: Array1<A>,
@@ -19,14 +23,18 @@ pub struct TruncatedSvdResult<A> {
1923
}
2024

2125
impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
26+
/// Returns singular values ordered by magnitude with indices.
2227
fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
28+
// numerate and square root eigenvalues
2329
let mut a = self.eigvals.iter()
2430
.map(|x| x.sqrt())
2531
.enumerate()
2632
.collect::<Vec<_>>();
2733

34+
// sort by magnitude
2835
a.sort_by(|(_,x), (_, y)| x.partial_cmp(&y).unwrap().reverse());
2936

37+
// filter low singular values away
3038
let (values, indices): (Vec<A>, Vec<usize>) = a.into_iter()
3139
.filter(|(_,x)| *x > NumCast::from(1e-5).unwrap())
3240
.map(|(a,b)| (b,a))
@@ -35,15 +43,18 @@ impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
3543
(Array1::from(values), indices)
3644
}
3745

46+
/// Returns singular values orderd by magnitude
3847
pub fn values(&self) -> Array1<A> {
3948
let (values, _) = self.singular_values_with_indices();
4049

4150
values
4251
}
4352

53+
/// Returns singular values, left-singular vectors and right-singular vectors
4454
pub fn values_vecs(&self) -> (Array2<A>, Array1<A>, Array2<A>) {
4555
let (values, indices) = self.singular_values_with_indices();
4656

57+
// branch n > m (for A is [n x m])
4758
let (u, v) = if self.ngm {
4859
let vlarge = self.eigvecs.select(Axis(1), &indices);
4960
let mut ularge = self.problem.dot(&vlarge);
@@ -105,18 +116,23 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
105116
}
106117

107118
// calculate the eigenvalues once
108-
pub fn once(&self, num: usize) -> Result<TruncatedSvdResult<A>> {
119+
pub fn once(self, num: usize) -> Result<TruncatedSvdResult<A>> {
109120
let (n,m) = (self.problem.nrows(), self.problem.ncols());
110121

111122
let x = Array2::random((usize::min(n,m), num), Uniform::new(0.0, 1.0))
112123
.mapv(|x| NumCast::from(x).unwrap());
113124

125+
// square precision because the SVD squares the eigenvalue as well
126+
let precision = self.precision * self.precision;
127+
128+
// use problem definition with less operations required
114129
let res = if n > m {
115-
lobpcg(|y| self.problem.t().dot(&self.problem.dot(&y)), x, None, None, self.precision, self.maxiter, self.order.clone())
130+
lobpcg(|y| self.problem.t().dot(&self.problem.dot(&y)), x, None, None, precision, self.maxiter, self.order.clone())
116131
} else {
117-
lobpcg(|y| self.problem.dot(&self.problem.t().dot(&y)), x, None, None, self.precision, self.maxiter, self.order.clone())
132+
lobpcg(|y| self.problem.dot(&self.problem.t().dot(&y)), x, None, None, precision, self.maxiter, self.order.clone())
118133
};
119134

135+
// convert into TruncatedSvdResult
120136
match res {
121137
EigResult::Ok(vals, vecs, _) | EigResult::Err(vals, vecs, _, _) => {
122138
Ok(TruncatedSvdResult {
@@ -136,7 +152,9 @@ mod tests {
136152
use crate::close_l2;
137153
use super::TruncatedSvd;
138154
use super::Order;
139-
use ndarray::{arr1, arr2};
155+
use ndarray::{arr1, arr2, Array2};
156+
use ndarray_rand::rand_distr::Uniform;
157+
use ndarray_rand::RandomExt;
140158

141159
#[test]
142160
fn test_truncated_svd() {
@@ -145,12 +163,28 @@ mod tests {
145163

146164
let res = TruncatedSvd::new(a, Order::Largest)
147165
.precision(1e-5)
148-
.maxiter(500)
166+
.maxiter(10)
149167
.once(2)
150168
.unwrap();
151169

152-
let (u, sigma, v_t) = res.values_vecs();
170+
let (_, sigma, _) = res.values_vecs();
153171

154172
close_l2(&sigma, &arr1(&[5.0, 3.0]), 1e-5);
155173
}
174+
175+
#[test]
176+
fn test_truncated_svd_random() {
177+
let a: Array2<f64> = Array2::random((50, 10), Uniform::new(0.0, 1.0));
178+
179+
let res = TruncatedSvd::new(a.clone(), Order::Largest)
180+
.precision(1e-5)
181+
.maxiter(10)
182+
.once(10)
183+
.unwrap();
184+
185+
let (u, sigma, v_t) = res.values_vecs();
186+
let reconstructed = u.dot(&Array2::from_diag(&sigma).dot(&v_t));
187+
188+
close_l2(&a, &reconstructed, 1e-5);
189+
}
156190
}

0 commit comments

Comments
 (0)