@@ -30,6 +30,307 @@ pub trait SVD_: Scalar {
3030 -> Result < SVDOutput < Self > > ;
3131}
3232
33+ pub struct SvdWork < T : Scalar > {
34+ pub ju : JobSvd ,
35+ pub jvt : JobSvd ,
36+ pub layout : MatrixLayout ,
37+ pub s : Vec < MaybeUninit < T :: Real > > ,
38+ pub u : Option < Vec < MaybeUninit < T > > > ,
39+ pub vt : Option < Vec < MaybeUninit < T > > > ,
40+ pub work : Vec < MaybeUninit < T > > ,
41+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
42+ }
43+
44+ #[ derive( Debug , Clone ) ]
45+ pub struct SvdRef < ' work , T : Scalar > {
46+ pub s : & ' work [ T :: Real ] ,
47+ pub u : Option < & ' work [ T ] > ,
48+ pub vt : Option < & ' work [ T ] > ,
49+ }
50+
51+ #[ derive( Debug , Clone ) ]
52+ pub struct SvdOwned < T : Scalar > {
53+ pub s : Vec < T :: Real > ,
54+ pub u : Option < Vec < T > > ,
55+ pub vt : Option < Vec < T > > ,
56+ }
57+
58+ pub trait SvdWorkImpl : Sized {
59+ type Elem : Scalar ;
60+ fn new ( layout : MatrixLayout , calc_u : bool , calc_vt : bool ) -> Result < Self > ;
61+ fn calc ( & mut self , a : & mut [ Self :: Elem ] ) -> Result < SvdRef < Self :: Elem > > ;
62+ fn eval ( self , a : & mut [ Self :: Elem ] ) -> Result < SvdOwned < Self :: Elem > > ;
63+ }
64+
65+ macro_rules! impl_svd_work_c {
66+ ( $s: ty, $svd: path) => {
67+ impl SvdWorkImpl for SvdWork <$s> {
68+ type Elem = $s;
69+
70+ fn new( layout: MatrixLayout , calc_u: bool , calc_vt: bool ) -> Result <Self > {
71+ let ju = match layout {
72+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_u) ,
73+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_vt) ,
74+ } ;
75+ let jvt = match layout {
76+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_vt) ,
77+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_u) ,
78+ } ;
79+
80+ let m = layout. lda( ) ;
81+ let mut u = match ju {
82+ JobSvd :: All => Some ( vec_uninit( ( m * m) as usize ) ) ,
83+ JobSvd :: None => None ,
84+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
85+ } ;
86+
87+ let n = layout. len( ) ;
88+ let mut vt = match jvt {
89+ JobSvd :: All => Some ( vec_uninit( ( n * n) as usize ) ) ,
90+ JobSvd :: None => None ,
91+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
92+ } ;
93+
94+ let k = std:: cmp:: min( m, n) ;
95+ let mut s = vec_uninit( k as usize ) ;
96+ let mut rwork = vec_uninit( 5 * k as usize ) ;
97+
98+ // eval work size
99+ let mut info = 0 ;
100+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
101+ unsafe {
102+ $svd(
103+ ju. as_ptr( ) ,
104+ jvt. as_ptr( ) ,
105+ & m,
106+ & n,
107+ std:: ptr:: null_mut( ) ,
108+ & m,
109+ AsPtr :: as_mut_ptr( & mut s) ,
110+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
111+ & m,
112+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
113+ & n,
114+ AsPtr :: as_mut_ptr( & mut work_size) ,
115+ & ( -1 ) ,
116+ AsPtr :: as_mut_ptr( & mut rwork) ,
117+ & mut info,
118+ ) ;
119+ }
120+ info. as_lapack_result( ) ?;
121+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
122+ let work = vec_uninit( lwork) ;
123+ Ok ( SvdWork {
124+ layout,
125+ ju,
126+ jvt,
127+ s,
128+ u,
129+ vt,
130+ work,
131+ rwork: Some ( rwork) ,
132+ } )
133+ }
134+
135+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
136+ let m = self . layout. lda( ) ;
137+ let n = self . layout. len( ) ;
138+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
139+
140+ let mut info = 0 ;
141+ unsafe {
142+ $svd(
143+ self . ju. as_ptr( ) ,
144+ self . jvt. as_ptr( ) ,
145+ & m,
146+ & n,
147+ AsPtr :: as_mut_ptr( a) ,
148+ & m,
149+ AsPtr :: as_mut_ptr( & mut self . s) ,
150+ AsPtr :: as_mut_ptr(
151+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
152+ ) ,
153+ & m,
154+ AsPtr :: as_mut_ptr(
155+ self . vt
156+ . as_mut( )
157+ . map( |x| x. as_mut_slice( ) )
158+ . unwrap_or( & mut [ ] ) ,
159+ ) ,
160+ & n,
161+ AsPtr :: as_mut_ptr( & mut self . work) ,
162+ & ( lwork as i32 ) ,
163+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
164+ & mut info,
165+ ) ;
166+ }
167+ info. as_lapack_result( ) ?;
168+
169+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
170+ let u = self
171+ . u
172+ . as_ref( )
173+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
174+ let vt = self
175+ . vt
176+ . as_ref( )
177+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
178+
179+ match self . layout {
180+ MatrixLayout :: F { .. } => Ok ( SvdRef { s, u, vt } ) ,
181+ MatrixLayout :: C { .. } => Ok ( SvdRef { s, u: vt, vt: u } ) ,
182+ }
183+ }
184+
185+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
186+ let _ref = self . calc( a) ?;
187+ let s = unsafe { self . s. assume_init( ) } ;
188+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
189+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
190+ match self . layout {
191+ MatrixLayout :: F { .. } => Ok ( SvdOwned { s, u, vt } ) ,
192+ MatrixLayout :: C { .. } => Ok ( SvdOwned { s, u: vt, vt: u } ) ,
193+ }
194+ }
195+ }
196+ } ;
197+ }
198+ impl_svd_work_c ! ( c64, lapack_sys:: zgesvd_) ;
199+ impl_svd_work_c ! ( c32, lapack_sys:: cgesvd_) ;
200+
201+ macro_rules! impl_svd_work_r {
202+ ( $s: ty, $svd: path) => {
203+ impl SvdWorkImpl for SvdWork <$s> {
204+ type Elem = $s;
205+
206+ fn new( layout: MatrixLayout , calc_u: bool , calc_vt: bool ) -> Result <Self > {
207+ let ju = match layout {
208+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_u) ,
209+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_vt) ,
210+ } ;
211+ let jvt = match layout {
212+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_vt) ,
213+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_u) ,
214+ } ;
215+
216+ let m = layout. lda( ) ;
217+ let mut u = match ju {
218+ JobSvd :: All => Some ( vec_uninit( ( m * m) as usize ) ) ,
219+ JobSvd :: None => None ,
220+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
221+ } ;
222+
223+ let n = layout. len( ) ;
224+ let mut vt = match jvt {
225+ JobSvd :: All => Some ( vec_uninit( ( n * n) as usize ) ) ,
226+ JobSvd :: None => None ,
227+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
228+ } ;
229+
230+ let k = std:: cmp:: min( m, n) ;
231+ let mut s = vec_uninit( k as usize ) ;
232+
233+ // eval work size
234+ let mut info = 0 ;
235+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
236+ unsafe {
237+ $svd(
238+ ju. as_ptr( ) ,
239+ jvt. as_ptr( ) ,
240+ & m,
241+ & n,
242+ std:: ptr:: null_mut( ) ,
243+ & m,
244+ AsPtr :: as_mut_ptr( & mut s) ,
245+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
246+ & m,
247+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
248+ & n,
249+ AsPtr :: as_mut_ptr( & mut work_size) ,
250+ & ( -1 ) ,
251+ & mut info,
252+ ) ;
253+ }
254+ info. as_lapack_result( ) ?;
255+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
256+ let work = vec_uninit( lwork) ;
257+ Ok ( SvdWork {
258+ layout,
259+ ju,
260+ jvt,
261+ s,
262+ u,
263+ vt,
264+ work,
265+ rwork: None ,
266+ } )
267+ }
268+
269+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
270+ let m = self . layout. lda( ) ;
271+ let n = self . layout. len( ) ;
272+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
273+
274+ let mut info = 0 ;
275+ unsafe {
276+ $svd(
277+ self . ju. as_ptr( ) ,
278+ self . jvt. as_ptr( ) ,
279+ & m,
280+ & n,
281+ AsPtr :: as_mut_ptr( a) ,
282+ & m,
283+ AsPtr :: as_mut_ptr( & mut self . s) ,
284+ AsPtr :: as_mut_ptr(
285+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
286+ ) ,
287+ & m,
288+ AsPtr :: as_mut_ptr(
289+ self . vt
290+ . as_mut( )
291+ . map( |x| x. as_mut_slice( ) )
292+ . unwrap_or( & mut [ ] ) ,
293+ ) ,
294+ & n,
295+ AsPtr :: as_mut_ptr( & mut self . work) ,
296+ & ( lwork as i32 ) ,
297+ & mut info,
298+ ) ;
299+ }
300+ info. as_lapack_result( ) ?;
301+
302+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
303+ let u = self
304+ . u
305+ . as_ref( )
306+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
307+ let vt = self
308+ . vt
309+ . as_ref( )
310+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
311+
312+ match self . layout {
313+ MatrixLayout :: F { .. } => Ok ( SvdRef { s, u, vt } ) ,
314+ MatrixLayout :: C { .. } => Ok ( SvdRef { s, u: vt, vt: u } ) ,
315+ }
316+ }
317+
318+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
319+ let _ref = self . calc( a) ?;
320+ let s = unsafe { self . s. assume_init( ) } ;
321+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
322+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
323+ match self . layout {
324+ MatrixLayout :: F { .. } => Ok ( SvdOwned { s, u, vt } ) ,
325+ MatrixLayout :: C { .. } => Ok ( SvdOwned { s, u: vt, vt: u } ) ,
326+ }
327+ }
328+ }
329+ } ;
330+ }
331+ impl_svd_work_r ! ( f64 , lapack_sys:: dgesvd_) ;
332+ impl_svd_work_r ! ( f32 , lapack_sys:: sgesvd_) ;
333+
33334macro_rules! impl_svd {
34335 ( @real, $scalar: ty, $gesvd: path) => {
35336 impl_svd!( @body, $scalar, $gesvd, ) ;
0 commit comments