@@ -10,6 +10,10 @@ use crate::{Scalar, Lapack};
1010use super :: lobpcg:: { lobpcg, EigResult , Order } ;
1111use crate :: error:: Result ;
1212
13+ /// The result of an eigenvalue decomposition for SVD
14+ ///
15+ /// Provides methods for either calculating just the singular values with reduced cost or the
16+ /// vectors as well.
1317#[ derive( Debug ) ]
1418pub struct TruncatedSvdResult < A > {
1519 eigvals : Array1 < A > ,
@@ -19,14 +23,18 @@ pub struct TruncatedSvdResult<A> {
1923}
2024
2125impl < A : Float + PartialOrd + DivAssign < A > + ' static > TruncatedSvdResult < A > {
26+ /// Returns singular values ordered by magnitude with indices.
2227 fn singular_values_with_indices ( & self ) -> ( Array1 < A > , Vec < usize > ) {
28+ // numerate and square root eigenvalues
2329 let mut a = self . eigvals . iter ( )
2430 . map ( |x| x. sqrt ( ) )
2531 . enumerate ( )
2632 . collect :: < Vec < _ > > ( ) ;
2733
34+ // sort by magnitude
2835 a. sort_by ( |( _, x) , ( _, y) | x. partial_cmp ( & y) . unwrap ( ) . reverse ( ) ) ;
2936
37+ // filter low singular values away
3038 let ( values, indices) : ( Vec < A > , Vec < usize > ) = a. into_iter ( )
3139 . filter ( |( _, x) | * x > NumCast :: from ( 1e-5 ) . unwrap ( ) )
3240 . map ( |( a, b) | ( b, a) )
@@ -35,15 +43,18 @@ impl<A: Float + PartialOrd + DivAssign<A> + 'static> TruncatedSvdResult<A> {
3543 ( Array1 :: from ( values) , indices)
3644 }
3745
46+ /// Returns singular values orderd by magnitude
3847 pub fn values ( & self ) -> Array1 < A > {
3948 let ( values, _) = self . singular_values_with_indices ( ) ;
4049
4150 values
4251 }
4352
53+ /// Returns singular values, left-singular vectors and right-singular vectors
4454 pub fn values_vecs ( & self ) -> ( Array2 < A > , Array1 < A > , Array2 < A > ) {
4555 let ( values, indices) = self . singular_values_with_indices ( ) ;
4656
57+ // branch n > m (for A is [n x m])
4758 let ( u, v) = if self . ngm {
4859 let vlarge = self . eigvecs . select ( Axis ( 1 ) , & indices) ;
4960 let mut ularge = self . problem . dot ( & vlarge) ;
@@ -105,18 +116,23 @@ impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedSvd<A> {
105116 }
106117
107118 // calculate the eigenvalues once
108- pub fn once ( & self , num : usize ) -> Result < TruncatedSvdResult < A > > {
119+ pub fn once ( self , num : usize ) -> Result < TruncatedSvdResult < A > > {
109120 let ( n, m) = ( self . problem . nrows ( ) , self . problem . ncols ( ) ) ;
110121
111122 let x = Array2 :: random ( ( usize:: min ( n, m) , num) , Uniform :: new ( 0.0 , 1.0 ) )
112123 . mapv ( |x| NumCast :: from ( x) . unwrap ( ) ) ;
113124
125+ // square precision because the SVD squares the eigenvalue as well
126+ let precision = self . precision * self . precision ;
127+
128+ // use problem definition with less operations required
114129 let res = if n > m {
115- lobpcg ( |y| self . problem . t ( ) . dot ( & self . problem . dot ( & y) ) , x, None , None , self . precision , self . maxiter , self . order . clone ( ) )
130+ lobpcg ( |y| self . problem . t ( ) . dot ( & self . problem . dot ( & y) ) , x, None , None , precision, self . maxiter , self . order . clone ( ) )
116131 } else {
117- lobpcg ( |y| self . problem . dot ( & self . problem . t ( ) . dot ( & y) ) , x, None , None , self . precision , self . maxiter , self . order . clone ( ) )
132+ lobpcg ( |y| self . problem . dot ( & self . problem . t ( ) . dot ( & y) ) , x, None , None , precision, self . maxiter , self . order . clone ( ) )
118133 } ;
119134
135+ // convert into TruncatedSvdResult
120136 match res {
121137 EigResult :: Ok ( vals, vecs, _) | EigResult :: Err ( vals, vecs, _, _) => {
122138 Ok ( TruncatedSvdResult {
@@ -136,7 +152,9 @@ mod tests {
136152 use crate :: close_l2;
137153 use super :: TruncatedSvd ;
138154 use super :: Order ;
139- use ndarray:: { arr1, arr2} ;
155+ use ndarray:: { arr1, arr2, Array2 } ;
156+ use ndarray_rand:: rand_distr:: Uniform ;
157+ use ndarray_rand:: RandomExt ;
140158
141159 #[ test]
142160 fn test_truncated_svd ( ) {
@@ -145,12 +163,28 @@ mod tests {
145163
146164 let res = TruncatedSvd :: new ( a, Order :: Largest )
147165 . precision ( 1e-5 )
148- . maxiter ( 500 )
166+ . maxiter ( 10 )
149167 . once ( 2 )
150168 . unwrap ( ) ;
151169
152- let ( u , sigma, v_t ) = res. values_vecs ( ) ;
170+ let ( _ , sigma, _ ) = res. values_vecs ( ) ;
153171
154172 close_l2 ( & sigma, & arr1 ( & [ 5.0 , 3.0 ] ) , 1e-5 ) ;
155173 }
174+
175+ #[ test]
176+ fn test_truncated_svd_random ( ) {
177+ let a: Array2 < f64 > = Array2 :: random ( ( 50 , 10 ) , Uniform :: new ( 0.0 , 1.0 ) ) ;
178+
179+ let res = TruncatedSvd :: new ( a. clone ( ) , Order :: Largest )
180+ . precision ( 1e-5 )
181+ . maxiter ( 10 )
182+ . once ( 10 )
183+ . unwrap ( ) ;
184+
185+ let ( u, sigma, v_t) = res. values_vecs ( ) ;
186+ let reconstructed = u. dot ( & Array2 :: from_diag ( & sigma) . dot ( & v_t) ) ;
187+
188+ close_l2 ( & a, & reconstructed, 1e-5 ) ;
189+ }
156190}
0 commit comments