Skip to content

Commit 68d8ec7

Browse files
committed
Truncate svd with float eps; Take preconditioner as a closure
1 parent f2909f0 commit 68d8ec7

File tree

4 files changed

+73
-34
lines changed

4 files changed

+73
-34
lines changed

examples/truncated_eig.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use ndarray_linalg::*;
77
fn main() {
88
let n = 10;
99
let v = random_unitary(n);
10+
1011
// set eigenvalues in decreasing order
1112
let t = Array1::linspace(n as f64, -(n as f64), n);
1213

src/lobpcg/eig.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,27 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> {
6464
let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0))
6565
.mapv(|x| NumCast::from(x).unwrap());
6666

67-
lobpcg(
68-
|y| self.problem.dot(&y),
69-
x,
70-
self.preconditioner.clone(),
71-
self.constraints.clone(),
72-
self.precision,
73-
self.maxiter,
74-
self.order.clone(),
75-
)
67+
if let Some(ref preconditioner) = self.preconditioner {
68+
lobpcg(
69+
|y| self.problem.dot(&y),
70+
x,
71+
|mut y| y.assign(&preconditioner.dot(&y)),
72+
self.constraints.clone(),
73+
self.precision,
74+
self.maxiter,
75+
self.order.clone(),
76+
)
77+
} else {
78+
lobpcg(
79+
|y| self.problem.dot(&y),
80+
x,
81+
|_| {},
82+
self.constraints.clone(),
83+
self.precision,
84+
self.maxiter,
85+
self.order.clone(),
86+
)
87+
}
7688
}
7789
}
7890

src/lobpcg/lobpcg.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2
147147
/// for it. All iterations are tracked and the optimal solution returned. In case of an error a
148148
/// special variant `EigResult::NotConverged` additionally carries the error. This can happen when
149149
/// the precision of the matrix is too low (switch from `f32` to `f64` for example).
150-
pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) -> Array2<A>>(
150+
pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) -> Array2<A>, G: Fn(ArrayViewMut2<A>)>(
151151
a: F,
152152
mut x: Array2<A>,
153-
m: Option<Array2<A>>,
153+
m: G,
154154
y: Option<Array2<A>>,
155155
tol: A::Real,
156156
maxiter: usize,
@@ -246,9 +246,9 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
246246

247247
// select active eigenvalues, apply pre-conditioner, orthogonalize to Y and orthonormalize
248248
let mut active_block_r = ndarray_mask(r.view(), &activemask);
249-
if let Some(ref m) = m {
250-
active_block_r = m.dot(&active_block_r);
251-
}
249+
// apply preconditioner
250+
m(active_block_r.view_mut());
251+
252252
if let (Some(ref y), Some(ref fact_yy)) = (&y, &fact_yy) {
253253
apply_constraints(active_block_r.view_mut(), fact_yy, y.view());
254254
}
@@ -453,7 +453,7 @@ mod tests {
453453
let n = a.len_of(Axis(0));
454454
let x: Array2<f64> = Array2::random((n, num), Uniform::new(0.0, 1.0));
455455

456-
let result = lobpcg(|y| a.dot(&y), x, None, None, 1e-10, 2 * n, order);
456+
let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-10, 2 * n, order);
457457
match result {
458458
EigResult::Ok(vals, _, r_norms) | EigResult::Err(vals, _, r_norms, _) => {
459459
// check convergence
@@ -510,7 +510,7 @@ mod tests {
510510
let x: Array2<f64> = Array2::random((10, 1), Uniform::new(0.0, 1.0));
511511
let y: Array2<f64> = arr2(&[[1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.]]).reversed_axes();
512512

513-
let result = lobpcg(|y| a.dot(&y), x, None, Some(y), 1e-10, 100, Order::Smallest);
513+
let result = lobpcg(|y| a.dot(&y), x, |_| {}, Some(y), 1e-10, 100, Order::Smallest);
514514
dbg!(&result);
515515
match result {
516516
EigResult::Ok(vals, vecs, r_norms) | EigResult::Err(vals, vecs, r_norms, _) => {

src/lobpcg/svd.rs

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1+
///! Truncated singular value decomposition
2+
///!
3+
///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
14
use super::lobpcg::{lobpcg, EigResult, Order};
25
use crate::error::Result;
36
use crate::{Lapack, Scalar};
47
use ndarray::prelude::*;
58
use ndarray_rand::rand_distr::Uniform;
69
use ndarray_rand::RandomExt;
710
use num_traits::{Float, NumCast};
8-
///! Implements truncated singular value decomposition
9-
///
1011
use std::ops::DivAssign;
1112

12-
/// The result of an eigenvalue decomposition for SVD
13+
/// The result of a eigenvalue decomposition, not yet transformed into singular values/vectors
1314
///
1415
/// Provides methods for either calculating just the singular values with reduced cost or the
15-
/// vectors as well.
16+
/// vectors with additional cost of matrix multiplication.
1617
#[derive(Debug)]
1718
pub struct TruncatedSvdResult<A> {
1819
eigvals: Array1<A>,
@@ -21,26 +22,31 @@ pub struct TruncatedSvdResult<A> {
2122
ngm: bool,
2223
}
2324

24-
impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
25+
impl<A: Float + PartialOrd + DivAssign<A> + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
2526
/// Returns singular values ordered by magnitude with indices.
2627
fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
27-
// numerate and square root eigenvalues
28-
let mut a = self.eigvals.iter().map(|x| x.sqrt()).enumerate().collect::<Vec<_>>();
28+
// numerate eigenvalues
29+
let mut a = self.eigvals.iter().enumerate().collect::<Vec<_>>();
2930

3031
// sort by magnitude
3132
a.sort_by(|(_, x), (_, y)| x.partial_cmp(&y).unwrap().reverse());
3233

34+
// calculate cut-off magnitude (borrowed from scipy)
35+
let cutoff = A::epsilon() * // float precision
36+
A::correction() * // correction term (see trait below)
37+
*a[0].1; // max eigenvalue
38+
3339
// filter low singular values away
3440
let (values, indices): (Vec<A>, Vec<usize>) = a
3541
.into_iter()
36-
.filter(|(_, x)| *x > NumCast::from(1e-5).unwrap())
37-
.map(|(a, b)| (b, a))
42+
.filter(|(_, x)| *x > &cutoff)
43+
.map(|(a, b)| (b.sqrt(), a))
3844
.unzip();
3945

4046
(Array1::from(values), indices)
4147
}
4248

43-
/// Returns singular values orderd by magnitude
49+
/// Returns singular values ordered by magnitude
4450
pub fn values(&self) -> Array1<A> {
4551
let (values, _) = self.singular_values_with_indices();
4652

@@ -82,10 +88,8 @@ impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
8288

8389
/// Truncated singular value decomposition
8490
///
85-
/// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
86-
/// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
87-
/// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
88-
/// pair.
91+
/// Wraps the LOBPCG algorithm and provides convenient builder-pattern access to
92+
/// parameter like maximal iteration, precision and constraint matrix.
8993
pub struct TruncatedSvd<A: Scalar> {
9094
order: Order,
9195
problem: Array2<A>,
@@ -117,9 +121,15 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
117121

118122
// calculate the eigenvalue decomposition
119123
pub fn decompose(self, num: usize) -> Result<TruncatedSvdResult<A>> {
124+
if num < 1 {
125+
panic!("The number of singular values to compute should be larger than zero!");
126+
}
127+
120128
let (n, m) = (self.problem.nrows(), self.problem.ncols());
121129

122-
let x = Array2::random((usize::min(n, m), num), Uniform::new(0.0, 1.0)).mapv(|x| NumCast::from(x).unwrap());
130+
// generate initial matrix
131+
let x = Array2::random((usize::min(n, m), num), Uniform::new(0.0, 1.0))
132+
.mapv(|x| NumCast::from(x).unwrap());
123133

124134
// square precision because the SVD squares the eigenvalue as well
125135
let precision = self.precision * self.precision;
@@ -129,7 +139,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
129139
lobpcg(
130140
|y| self.problem.t().dot(&self.problem.dot(&y)),
131141
x,
132-
None,
142+
|_| {},
133143
None,
134144
precision,
135145
self.maxiter,
@@ -139,7 +149,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
139149
lobpcg(
140150
|y| self.problem.dot(&self.problem.t().dot(&y)),
141151
x,
142-
None,
152+
|_| {},
143153
None,
144154
precision,
145155
self.maxiter,
@@ -160,6 +170,22 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
160170
}
161171
}
162172

173+
pub trait MagnitudeCorrection {
174+
fn correction() -> Self;
175+
}
176+
177+
impl MagnitudeCorrection for f32 {
178+
fn correction() -> Self {
179+
1.0e3
180+
}
181+
}
182+
183+
impl MagnitudeCorrection for f64 {
184+
fn correction() -> Self {
185+
1.0e6
186+
}
187+
}
188+
163189
#[cfg(test)]
164190
mod tests {
165191
use super::Order;
@@ -179,7 +205,7 @@ mod tests {
179205
.decompose(2)
180206
.unwrap();
181207

182-
let (_, sigma, _) = res.values_vecs();
208+
let (_, sigma, _) = res.values_vectors();
183209

184210
close_l2(&sigma, &arr1(&[5.0, 3.0]), 1e-5);
185211
}

0 commit comments

Comments
 (0)