1+ ///! Truncated singular value decomposition
2+ ///!
3+ ///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
14use super :: lobpcg:: { lobpcg, EigResult , Order } ;
25use crate :: error:: Result ;
36use crate :: { Lapack , Scalar } ;
47use ndarray:: prelude:: * ;
58use ndarray_rand:: rand_distr:: Uniform ;
69use ndarray_rand:: RandomExt ;
710use num_traits:: { Float , NumCast } ;
8- ///! Implements truncated singular value decomposition
9- ///
1011use std:: ops:: DivAssign ;
1112
12- /// The result of an eigenvalue decomposition for SVD
13+ /// The result of a eigenvalue decomposition, not yet transformed into singular values/vectors
1314///
1415/// Provides methods for either calculating just the singular values with reduced cost or the
15- /// vectors as well .
16+ /// vectors with additional cost of matrix multiplication .
1617#[ derive( Debug ) ]
1718pub struct TruncatedSvdResult < A > {
1819 eigvals : Array1 < A > ,
@@ -21,26 +22,31 @@ pub struct TruncatedSvdResult<A> {
2122 ngm : bool ,
2223}
2324
24- impl < A : Float + PartialOrd + DivAssign < A > + ' static > TruncatedSvdResult < A > {
25+ impl < A : Float + PartialOrd + DivAssign < A > + ' static + MagnitudeCorrection > TruncatedSvdResult < A > {
2526 /// Returns singular values ordered by magnitude with indices.
2627 fn singular_values_with_indices ( & self ) -> ( Array1 < A > , Vec < usize > ) {
27- // numerate and square root eigenvalues
28- let mut a = self . eigvals . iter ( ) . map ( |x| x . sqrt ( ) ) . enumerate ( ) . collect :: < Vec < _ > > ( ) ;
28+ // numerate eigenvalues
29+ let mut a = self . eigvals . iter ( ) . enumerate ( ) . collect :: < Vec < _ > > ( ) ;
2930
3031 // sort by magnitude
3132 a. sort_by ( |( _, x) , ( _, y) | x. partial_cmp ( & y) . unwrap ( ) . reverse ( ) ) ;
3233
34+ // calculate cut-off magnitude (borrowed from scipy)
35+ let cutoff = A :: epsilon ( ) * // float precision
36+ A :: correction ( ) * // correction term (see trait below)
37+ * a[ 0 ] . 1 ; // max eigenvalue
38+
3339 // filter low singular values away
3440 let ( values, indices) : ( Vec < A > , Vec < usize > ) = a
3541 . into_iter ( )
36- . filter ( |( _, x) | * x > NumCast :: from ( 1e-5 ) . unwrap ( ) )
37- . map ( |( a, b) | ( b, a) )
42+ . filter ( |( _, x) | * x > & cutoff )
43+ . map ( |( a, b) | ( b. sqrt ( ) , a) )
3844 . unzip ( ) ;
3945
4046 ( Array1 :: from ( values) , indices)
4147 }
4248
43- /// Returns singular values orderd by magnitude
49+ /// Returns singular values ordered by magnitude
4450 pub fn values ( & self ) -> Array1 < A > {
4551 let ( values, _) = self . singular_values_with_indices ( ) ;
4652
@@ -82,10 +88,8 @@ impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
8288
8389/// Truncated singular value decomposition
8490///
85- /// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
86- /// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
87- /// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
88- /// pair.
91+ /// Wraps the LOBPCG algorithm and provides convenient builder-pattern access to
92+ /// parameter like maximal iteration, precision and constraint matrix.
8993pub struct TruncatedSvd < A : Scalar > {
9094 order : Order ,
9195 problem : Array2 < A > ,
@@ -117,9 +121,15 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
117121
118122 // calculate the eigenvalue decomposition
119123 pub fn decompose ( self , num : usize ) -> Result < TruncatedSvdResult < A > > {
124+ if num < 1 {
125+ panic ! ( "The number of singular values to compute should be larger than zero!" ) ;
126+ }
127+
120128 let ( n, m) = ( self . problem . nrows ( ) , self . problem . ncols ( ) ) ;
121129
122- let x = Array2 :: random ( ( usize:: min ( n, m) , num) , Uniform :: new ( 0.0 , 1.0 ) ) . mapv ( |x| NumCast :: from ( x) . unwrap ( ) ) ;
130+ // generate initial matrix
131+ let x = Array2 :: random ( ( usize:: min ( n, m) , num) , Uniform :: new ( 0.0 , 1.0 ) )
132+ . mapv ( |x| NumCast :: from ( x) . unwrap ( ) ) ;
123133
124134 // square precision because the SVD squares the eigenvalue as well
125135 let precision = self . precision * self . precision ;
@@ -129,7 +139,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
129139 lobpcg (
130140 |y| self . problem . t ( ) . dot ( & self . problem . dot ( & y) ) ,
131141 x,
132- None ,
142+ |_| { } ,
133143 None ,
134144 precision,
135145 self . maxiter ,
@@ -139,7 +149,7 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
139149 lobpcg (
140150 |y| self . problem . dot ( & self . problem . t ( ) . dot ( & y) ) ,
141151 x,
142- None ,
152+ |_| { } ,
143153 None ,
144154 precision,
145155 self . maxiter ,
@@ -160,6 +170,22 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
160170 }
161171}
162172
173+ pub trait MagnitudeCorrection {
174+ fn correction ( ) -> Self ;
175+ }
176+
177+ impl MagnitudeCorrection for f32 {
178+ fn correction ( ) -> Self {
179+ 1.0e3
180+ }
181+ }
182+
183+ impl MagnitudeCorrection for f64 {
184+ fn correction ( ) -> Self {
185+ 1.0e6
186+ }
187+ }
188+
163189#[ cfg( test) ]
164190mod tests {
165191 use super :: Order ;
@@ -179,7 +205,7 @@ mod tests {
179205 . decompose ( 2 )
180206 . unwrap ( ) ;
181207
182- let ( _, sigma, _) = res. values_vecs ( ) ;
208+ let ( _, sigma, _) = res. values_vectors ( ) ;
183209
184210 close_l2 ( & sigma, & arr1 ( & [ 5.0 , 3.0 ] ) , 1e-5 ) ;
185211 }
0 commit comments