@@ -14,9 +14,303 @@ pub trait SVDDC_: Scalar {
1414 /// |:-------|:-------|:-------|:-------|
1515 /// | sgesdd | dgesdd | cgesdd | zgesdd |
1616 ///
17- fn svddc ( l : MatrixLayout , jobz : JobSvd , a : & mut [ Self ] ) -> Result < SvdOwned < Self > > ;
17+ fn svddc ( layout : MatrixLayout , jobz : JobSvd , a : & mut [ Self ] ) -> Result < SvdOwned < Self > > ;
1818}
1919
20+ pub struct SvdDcWork < T : Scalar > {
21+ pub jobz : JobSvd ,
22+ pub layout : MatrixLayout ,
23+ pub s : Vec < MaybeUninit < T :: Real > > ,
24+ pub u : Option < Vec < MaybeUninit < T > > > ,
25+ pub vt : Option < Vec < MaybeUninit < T > > > ,
26+ pub work : Vec < MaybeUninit < T > > ,
27+ pub iwork : Vec < MaybeUninit < i32 > > ,
28+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
29+ }
30+
31+ pub trait SvdDcWorkImpl : Sized {
32+ type Elem : Scalar ;
33+ fn new ( layout : MatrixLayout , jobz : JobSvd ) -> Result < Self > ;
34+ fn calc ( & mut self , a : & mut [ Self :: Elem ] ) -> Result < SvdRef < Self :: Elem > > ;
35+ fn eval ( self , a : & mut [ Self :: Elem ] ) -> Result < SvdOwned < Self :: Elem > > ;
36+ }
37+
38+ macro_rules! impl_svd_dc_work_c {
39+ ( $s: ty, $sdd: path) => {
40+ impl SvdDcWorkImpl for SvdDcWork <$s> {
41+ type Elem = $s;
42+
43+ fn new( layout: MatrixLayout , jobz: JobSvd ) -> Result <Self > {
44+ let m = layout. lda( ) ;
45+ let n = layout. len( ) ;
46+ let k = m. min( n) ;
47+ let ( u_col, vt_row) = match jobz {
48+ JobSvd :: All | JobSvd :: None => ( m, n) ,
49+ JobSvd :: Some => ( k, k) ,
50+ } ;
51+
52+ let mut s = vec_uninit( k as usize ) ;
53+ let ( mut u, mut vt) = match jobz {
54+ JobSvd :: All => (
55+ Some ( vec_uninit( ( m * m) as usize ) ) ,
56+ Some ( vec_uninit( ( n * n) as usize ) ) ,
57+ ) ,
58+ JobSvd :: Some => (
59+ Some ( vec_uninit( ( m * u_col) as usize ) ) ,
60+ Some ( vec_uninit( ( n * vt_row) as usize ) ) ,
61+ ) ,
62+ JobSvd :: None => ( None , None ) ,
63+ } ;
64+ let mut iwork = vec_uninit( 8 * k as usize ) ;
65+
66+ let mx = n. max( m) as usize ;
67+ let mn = n. min( m) as usize ;
68+ let lrwork = match jobz {
69+ JobSvd :: None => 7 * mn,
70+ _ => std:: cmp:: max( 5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn) ,
71+ } ;
72+ let mut rwork = vec_uninit( lrwork) ;
73+
74+ let mut info = 0 ;
75+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
76+ unsafe {
77+ $sdd(
78+ jobz. as_ptr( ) ,
79+ & m,
80+ & n,
81+ std:: ptr:: null_mut( ) ,
82+ & m,
83+ AsPtr :: as_mut_ptr( & mut s) ,
84+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
85+ & m,
86+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
87+ & vt_row,
88+ AsPtr :: as_mut_ptr( & mut work_size) ,
89+ & ( -1 ) ,
90+ AsPtr :: as_mut_ptr( & mut rwork) ,
91+ AsPtr :: as_mut_ptr( & mut iwork) ,
92+ & mut info,
93+ ) ;
94+ }
95+ info. as_lapack_result( ) ?;
96+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
97+ let work = vec_uninit( lwork) ;
98+ Ok ( SvdDcWork {
99+ layout,
100+ jobz,
101+ iwork,
102+ work,
103+ rwork: Some ( rwork) ,
104+ u,
105+ vt,
106+ s,
107+ } )
108+ }
109+
110+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
111+ let m = self . layout. lda( ) ;
112+ let n = self . layout. len( ) ;
113+ let k = m. min( n) ;
114+ let ( _, vt_row) = match self . jobz {
115+ JobSvd :: All | JobSvd :: None => ( m, n) ,
116+ JobSvd :: Some => ( k, k) ,
117+ } ;
118+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
119+
120+ let mut info = 0 ;
121+ unsafe {
122+ $sdd(
123+ self . jobz. as_ptr( ) ,
124+ & m,
125+ & n,
126+ AsPtr :: as_mut_ptr( a) ,
127+ & m,
128+ AsPtr :: as_mut_ptr( & mut self . s) ,
129+ AsPtr :: as_mut_ptr(
130+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
131+ ) ,
132+ & m,
133+ AsPtr :: as_mut_ptr(
134+ self . vt
135+ . as_mut( )
136+ . map( |x| x. as_mut_slice( ) )
137+ . unwrap_or( & mut [ ] ) ,
138+ ) ,
139+ & vt_row,
140+ AsPtr :: as_mut_ptr( & mut self . work) ,
141+ & lwork,
142+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
143+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
144+ & mut info,
145+ ) ;
146+ }
147+ info. as_lapack_result( ) ?;
148+
149+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
150+ let u = self
151+ . u
152+ . as_ref( )
153+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
154+ let vt = self
155+ . vt
156+ . as_ref( )
157+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
158+
159+ Ok ( match self . layout {
160+ MatrixLayout :: F { .. } => SvdRef { s, u, vt } ,
161+ MatrixLayout :: C { .. } => SvdRef { s, u: vt, vt: u } ,
162+ } )
163+ }
164+
165+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
166+ let _ref = self . calc( a) ?;
167+ let s = unsafe { self . s. assume_init( ) } ;
168+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
169+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
170+ Ok ( match self . layout {
171+ MatrixLayout :: F { .. } => SvdOwned { s, u, vt } ,
172+ MatrixLayout :: C { .. } => SvdOwned { s, u: vt, vt: u } ,
173+ } )
174+ }
175+ }
176+ } ;
177+ }
178+ impl_svd_dc_work_c ! ( c64, lapack_sys:: zgesdd_) ;
179+ impl_svd_dc_work_c ! ( c32, lapack_sys:: cgesdd_) ;
180+
181+ macro_rules! impl_svd_dc_work_r {
182+ ( $s: ty, $sdd: path) => {
183+ impl SvdDcWorkImpl for SvdDcWork <$s> {
184+ type Elem = $s;
185+
186+ fn new( layout: MatrixLayout , jobz: JobSvd ) -> Result <Self > {
187+ let m = layout. lda( ) ;
188+ let n = layout. len( ) ;
189+ let k = m. min( n) ;
190+ let ( u_col, vt_row) = match jobz {
191+ JobSvd :: All | JobSvd :: None => ( m, n) ,
192+ JobSvd :: Some => ( k, k) ,
193+ } ;
194+
195+ let mut s = vec_uninit( k as usize ) ;
196+ let ( mut u, mut vt) = match jobz {
197+ JobSvd :: All => (
198+ Some ( vec_uninit( ( m * m) as usize ) ) ,
199+ Some ( vec_uninit( ( n * n) as usize ) ) ,
200+ ) ,
201+ JobSvd :: Some => (
202+ Some ( vec_uninit( ( m * u_col) as usize ) ) ,
203+ Some ( vec_uninit( ( n * vt_row) as usize ) ) ,
204+ ) ,
205+ JobSvd :: None => ( None , None ) ,
206+ } ;
207+ let mut iwork = vec_uninit( 8 * k as usize ) ;
208+
209+ let mut info = 0 ;
210+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
211+ unsafe {
212+ $sdd(
213+ jobz. as_ptr( ) ,
214+ & m,
215+ & n,
216+ std:: ptr:: null_mut( ) ,
217+ & m,
218+ AsPtr :: as_mut_ptr( & mut s) ,
219+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
220+ & m,
221+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
222+ & vt_row,
223+ AsPtr :: as_mut_ptr( & mut work_size) ,
224+ & ( -1 ) ,
225+ AsPtr :: as_mut_ptr( & mut iwork) ,
226+ & mut info,
227+ ) ;
228+ }
229+ info. as_lapack_result( ) ?;
230+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
231+ let work = vec_uninit( lwork) ;
232+ Ok ( SvdDcWork {
233+ layout,
234+ jobz,
235+ iwork,
236+ work,
237+ rwork: None ,
238+ u,
239+ vt,
240+ s,
241+ } )
242+ }
243+
244+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
245+ let m = self . layout. lda( ) ;
246+ let n = self . layout. len( ) ;
247+ let k = m. min( n) ;
248+ let ( _, vt_row) = match self . jobz {
249+ JobSvd :: All | JobSvd :: None => ( m, n) ,
250+ JobSvd :: Some => ( k, k) ,
251+ } ;
252+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
253+
254+ let mut info = 0 ;
255+ unsafe {
256+ $sdd(
257+ self . jobz. as_ptr( ) ,
258+ & m,
259+ & n,
260+ AsPtr :: as_mut_ptr( a) ,
261+ & m,
262+ AsPtr :: as_mut_ptr( & mut self . s) ,
263+ AsPtr :: as_mut_ptr(
264+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
265+ ) ,
266+ & m,
267+ AsPtr :: as_mut_ptr(
268+ self . vt
269+ . as_mut( )
270+ . map( |x| x. as_mut_slice( ) )
271+ . unwrap_or( & mut [ ] ) ,
272+ ) ,
273+ & vt_row,
274+ AsPtr :: as_mut_ptr( & mut self . work) ,
275+ & lwork,
276+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
277+ & mut info,
278+ ) ;
279+ }
280+ info. as_lapack_result( ) ?;
281+
282+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
283+ let u = self
284+ . u
285+ . as_ref( )
286+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
287+ let vt = self
288+ . vt
289+ . as_ref( )
290+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
291+
292+ Ok ( match self . layout {
293+ MatrixLayout :: F { .. } => SvdRef { s, u, vt } ,
294+ MatrixLayout :: C { .. } => SvdRef { s, u: vt, vt: u } ,
295+ } )
296+ }
297+
298+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
299+ let _ref = self . calc( a) ?;
300+ let s = unsafe { self . s. assume_init( ) } ;
301+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
302+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
303+ Ok ( match self . layout {
304+ MatrixLayout :: F { .. } => SvdOwned { s, u, vt } ,
305+ MatrixLayout :: C { .. } => SvdOwned { s, u: vt, vt: u } ,
306+ } )
307+ }
308+ }
309+ } ;
310+ }
311+ impl_svd_dc_work_r ! ( f64 , lapack_sys:: dgesdd_) ;
312+ impl_svd_dc_work_r ! ( f32 , lapack_sys:: sgesdd_) ;
313+
20314macro_rules! impl_svddc {
21315 ( @real, $scalar: ty, $gesdd: path) => {
22316 impl_svddc!( @body, $scalar, $gesdd, ) ;
0 commit comments