@@ -12,6 +12,14 @@ pub struct LeastSquaresOwned<A: Scalar> {
1212 pub rank : i32 ,
1313}
1414
15+ /// Result of LeastSquares
16+ pub struct LeastSquaresRef < ' work , A : Scalar > {
17+ /// singular values
18+ pub singular_values : & ' work [ A :: Real ] ,
19+ /// The rank of the input matrix A
20+ pub rank : i32 ,
21+ }
22+
1523#[ cfg_attr( doc, katexit:: katexit) ]
1624/// Solve least square problem
1725pub trait LeastSquaresSvdDivideConquer_ : Scalar {
@@ -29,8 +37,325 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
2937 a : & mut [ Self ] ,
3038 b_layout : MatrixLayout ,
3139 b : & mut [ Self ] ,
32- ) -> Result < LeastSquaresOutput < Self > > ;
40+ ) -> Result < LeastSquaresOwned < Self > > ;
41+ }
42+
43+ pub struct LeastSquaresWork < T : Scalar > {
44+ pub a_layout : MatrixLayout ,
45+ pub b_layout : MatrixLayout ,
46+ pub singular_values : Vec < MaybeUninit < T :: Real > > ,
47+ pub work : Vec < MaybeUninit < T > > ,
48+ pub iwork : Vec < MaybeUninit < i32 > > ,
49+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
50+ }
51+
52+ pub trait LeastSquaresWorkImpl : Sized {
53+ type Elem : Scalar ;
54+ fn new ( a_layout : MatrixLayout , b_layout : MatrixLayout ) -> Result < Self > ;
55+ fn calc (
56+ & mut self ,
57+ a : & mut [ Self :: Elem ] ,
58+ b : & mut [ Self :: Elem ] ,
59+ ) -> Result < LeastSquaresRef < Self :: Elem > > ;
60+ fn eval (
61+ self ,
62+ a : & mut [ Self :: Elem ] ,
63+ b : & mut [ Self :: Elem ] ,
64+ ) -> Result < LeastSquaresOwned < Self :: Elem > > ;
65+ }
66+
67+ macro_rules! impl_least_squares_work_c {
68+ ( $c: ty, $lsd: path) => {
69+ impl LeastSquaresWorkImpl for LeastSquaresWork <$c> {
70+ type Elem = $c;
71+
72+ fn new( a_layout: MatrixLayout , b_layout: MatrixLayout ) -> Result <Self > {
73+ let ( m, n) = a_layout. size( ) ;
74+ let ( m_, nrhs) = b_layout. size( ) ;
75+ let k = m. min( n) ;
76+ assert!( m_ >= m) ;
77+
78+ let rcond = -1. ;
79+ let mut singular_values = vec_uninit( k as usize ) ;
80+ let mut rank: i32 = 0 ;
81+
82+ // eval work size
83+ let mut info = 0 ;
84+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
85+ let mut iwork_size = [ 0 ] ;
86+ let mut rwork = [ <Self :: Elem as Scalar >:: Real :: zero( ) ] ;
87+ unsafe {
88+ $lsd(
89+ & m,
90+ & n,
91+ & nrhs,
92+ std:: ptr:: null_mut( ) ,
93+ & a_layout. lda( ) ,
94+ std:: ptr:: null_mut( ) ,
95+ & b_layout. lda( ) ,
96+ AsPtr :: as_mut_ptr( & mut singular_values) ,
97+ & rcond,
98+ & mut rank,
99+ AsPtr :: as_mut_ptr( & mut work_size) ,
100+ & ( -1 ) ,
101+ AsPtr :: as_mut_ptr( & mut rwork) ,
102+ iwork_size. as_mut_ptr( ) ,
103+ & mut info,
104+ )
105+ } ;
106+ info. as_lapack_result( ) ?;
107+
108+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
109+ let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
110+ let lrwork = rwork[ 0 ] . to_usize( ) . unwrap( ) ;
111+
112+ let work = vec_uninit( lwork) ;
113+ let iwork = vec_uninit( liwork) ;
114+ let rwork = vec_uninit( lrwork) ;
115+
116+ Ok ( LeastSquaresWork {
117+ a_layout,
118+ b_layout,
119+ work,
120+ iwork,
121+ rwork: Some ( rwork) ,
122+ singular_values,
123+ } )
124+ }
125+
126+ fn calc(
127+ & mut self ,
128+ a: & mut [ Self :: Elem ] ,
129+ b: & mut [ Self :: Elem ] ,
130+ ) -> Result <LeastSquaresRef <Self :: Elem >> {
131+ let ( m, n) = self . a_layout. size( ) ;
132+ let ( m_, nrhs) = self . b_layout. size( ) ;
133+ assert!( m_ >= m) ;
134+
135+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
136+
137+ // Transpose if a is C-continuous
138+ let mut a_t = None ;
139+ let a_layout = match self . a_layout {
140+ MatrixLayout :: C { .. } => {
141+ let ( layout, t) = transpose( self . a_layout, a) ;
142+ a_t = Some ( t) ;
143+ layout
144+ }
145+ MatrixLayout :: F { .. } => self . a_layout,
146+ } ;
147+
148+ // Transpose if b is C-continuous
149+ let mut b_t = None ;
150+ let b_layout = match self . b_layout {
151+ MatrixLayout :: C { .. } => {
152+ let ( layout, t) = transpose( self . b_layout, b) ;
153+ b_t = Some ( t) ;
154+ layout
155+ }
156+ MatrixLayout :: F { .. } => self . b_layout,
157+ } ;
158+
159+ let rcond: <Self :: Elem as Scalar >:: Real = -1. ;
160+ let mut rank: i32 = 0 ;
161+
162+ let mut info = 0 ;
163+ unsafe {
164+ $lsd(
165+ & m,
166+ & n,
167+ & nrhs,
168+ AsPtr :: as_mut_ptr( a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ) ,
169+ & a_layout. lda( ) ,
170+ AsPtr :: as_mut_ptr( b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ) ,
171+ & b_layout. lda( ) ,
172+ AsPtr :: as_mut_ptr( & mut self . singular_values) ,
173+ & rcond,
174+ & mut rank,
175+ AsPtr :: as_mut_ptr( & mut self . work) ,
176+ & lwork,
177+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
178+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
179+ & mut info,
180+ ) ;
181+ }
182+ info. as_lapack_result( ) ?;
183+
184+ let singular_values = unsafe { self . singular_values. slice_assume_init_ref( ) } ;
185+
186+ // Skip a_t -> a transpose because A has been destroyed
187+ // Re-transpose b
188+ if let Some ( b_t) = b_t {
189+ transpose_over( b_layout, & b_t, b) ;
190+ }
191+
192+ Ok ( LeastSquaresRef {
193+ singular_values,
194+ rank,
195+ } )
196+ }
197+
198+ fn eval(
199+ mut self ,
200+ a: & mut [ Self :: Elem ] ,
201+ b: & mut [ Self :: Elem ] ,
202+ ) -> Result <LeastSquaresOwned <Self :: Elem >> {
203+ let LeastSquaresRef { rank, .. } = self . calc( a, b) ?;
204+ let singular_values = unsafe { self . singular_values. assume_init( ) } ;
205+ Ok ( LeastSquaresOwned {
206+ singular_values,
207+ rank,
208+ } )
209+ }
210+ }
211+ } ;
212+ }
213+ impl_least_squares_work_c ! ( c64, lapack_sys:: zgelsd_) ;
214+ impl_least_squares_work_c ! ( c32, lapack_sys:: cgelsd_) ;
215+
216+ macro_rules! impl_least_squares_work_r {
217+ ( $c: ty, $lsd: path) => {
218+ impl LeastSquaresWorkImpl for LeastSquaresWork <$c> {
219+ type Elem = $c;
220+
221+ fn new( a_layout: MatrixLayout , b_layout: MatrixLayout ) -> Result <Self > {
222+ let ( m, n) = a_layout. size( ) ;
223+ let ( m_, nrhs) = b_layout. size( ) ;
224+ let k = m. min( n) ;
225+ assert!( m_ >= m) ;
226+
227+ let rcond = -1. ;
228+ let mut singular_values = vec_uninit( k as usize ) ;
229+ let mut rank: i32 = 0 ;
230+
231+ // eval work size
232+ let mut info = 0 ;
233+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
234+ let mut iwork_size = [ 0 ] ;
235+ unsafe {
236+ $lsd(
237+ & m,
238+ & n,
239+ & nrhs,
240+ std:: ptr:: null_mut( ) ,
241+ & a_layout. lda( ) ,
242+ std:: ptr:: null_mut( ) ,
243+ & b_layout. lda( ) ,
244+ AsPtr :: as_mut_ptr( & mut singular_values) ,
245+ & rcond,
246+ & mut rank,
247+ AsPtr :: as_mut_ptr( & mut work_size) ,
248+ & ( -1 ) ,
249+ iwork_size. as_mut_ptr( ) ,
250+ & mut info,
251+ )
252+ } ;
253+ info. as_lapack_result( ) ?;
254+
255+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
256+ let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
257+
258+ let work = vec_uninit( lwork) ;
259+ let iwork = vec_uninit( liwork) ;
260+
261+ Ok ( LeastSquaresWork {
262+ a_layout,
263+ b_layout,
264+ work,
265+ iwork,
266+ rwork: None ,
267+ singular_values,
268+ } )
269+ }
270+
271+ fn calc(
272+ & mut self ,
273+ a: & mut [ Self :: Elem ] ,
274+ b: & mut [ Self :: Elem ] ,
275+ ) -> Result <LeastSquaresRef <Self :: Elem >> {
276+ let ( m, n) = self . a_layout. size( ) ;
277+ let ( m_, nrhs) = self . b_layout. size( ) ;
278+ assert!( m_ >= m) ;
279+
280+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
281+
282+ // Transpose if a is C-continuous
283+ let mut a_t = None ;
284+ let a_layout = match self . a_layout {
285+ MatrixLayout :: C { .. } => {
286+ let ( layout, t) = transpose( self . a_layout, a) ;
287+ a_t = Some ( t) ;
288+ layout
289+ }
290+ MatrixLayout :: F { .. } => self . a_layout,
291+ } ;
292+
293+ // Transpose if b is C-continuous
294+ let mut b_t = None ;
295+ let b_layout = match self . b_layout {
296+ MatrixLayout :: C { .. } => {
297+ let ( layout, t) = transpose( self . b_layout, b) ;
298+ b_t = Some ( t) ;
299+ layout
300+ }
301+ MatrixLayout :: F { .. } => self . b_layout,
302+ } ;
303+
304+ let rcond: <Self :: Elem as Scalar >:: Real = -1. ;
305+ let mut rank: i32 = 0 ;
306+
307+ let mut info = 0 ;
308+ unsafe {
309+ $lsd(
310+ & m,
311+ & n,
312+ & nrhs,
313+ AsPtr :: as_mut_ptr( a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ) ,
314+ & a_layout. lda( ) ,
315+ AsPtr :: as_mut_ptr( b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ) ,
316+ & b_layout. lda( ) ,
317+ AsPtr :: as_mut_ptr( & mut self . singular_values) ,
318+ & rcond,
319+ & mut rank,
320+ AsPtr :: as_mut_ptr( & mut self . work) ,
321+ & lwork,
322+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
323+ & mut info,
324+ ) ;
325+ }
326+ info. as_lapack_result( ) ?;
327+
328+ let singular_values = unsafe { self . singular_values. slice_assume_init_ref( ) } ;
329+
330+ // Skip a_t -> a transpose because A has been destroyed
331+ // Re-transpose b
332+ if let Some ( b_t) = b_t {
333+ transpose_over( b_layout, & b_t, b) ;
334+ }
335+
336+ Ok ( LeastSquaresRef {
337+ singular_values,
338+ rank,
339+ } )
340+ }
341+
342+ fn eval(
343+ mut self ,
344+ a: & mut [ Self :: Elem ] ,
345+ b: & mut [ Self :: Elem ] ,
346+ ) -> Result <LeastSquaresOwned <Self :: Elem >> {
347+ let LeastSquaresRef { rank, .. } = self . calc( a, b) ?;
348+ let singular_values = unsafe { self . singular_values. assume_init( ) } ;
349+ Ok ( LeastSquaresOwned {
350+ singular_values,
351+ rank,
352+ } )
353+ }
354+ }
355+ } ;
33356}
357+ impl_least_squares_work_r ! ( f64 , lapack_sys:: dgelsd_) ;
358+ impl_least_squares_work_r ! ( f32 , lapack_sys:: sgelsd_) ;
34359
35360macro_rules! impl_least_squares {
36361 ( @real, $scalar: ty, $gelsd: path) => {
0 commit comments