|
1 | 1 | //! Singular-value decomposition |
| 2 | +//! |
| 3 | +//! LAPACK correspondance |
| 4 | +//! ---------------------- |
| 5 | +//! |
| 6 | +//! | f32 | f64 | c32 | c64 | |
| 7 | +//! |:-------|:-------|:-------|:-------| |
| 8 | +//! | sgesvd | dgesvd | cgesvd | zgesvd | |
| 9 | +//! |
2 | 10 |
|
3 | 11 | use super::{error::*, layout::*, *}; |
4 | 12 | use cauchy::*; |
5 | 13 | use num_traits::{ToPrimitive, Zero}; |
6 | 14 |
|
7 | | -/// Result of SVD |
8 | | -pub struct SVDOutput<A: Scalar> { |
9 | | - /// diagonal values |
10 | | - pub s: Vec<A::Real>, |
11 | | - /// Unitary matrix for destination space |
12 | | - pub u: Option<Vec<A>>, |
13 | | - /// Unitary matrix for departure space |
14 | | - pub vt: Option<Vec<A>>, |
15 | | -} |
16 | | - |
17 | | -#[cfg_attr(doc, katexit::katexit)] |
18 | | -/// Singular value decomposition |
19 | | -pub trait SVD_: Scalar { |
20 | | - /// Compute singular value decomposition $A = U \Sigma V^T$ |
21 | | - /// |
22 | | - /// LAPACK correspondance |
23 | | - /// ---------------------- |
24 | | - /// |
25 | | - /// | f32 | f64 | c32 | c64 | |
26 | | - /// |:-------|:-------|:-------|:-------| |
27 | | - /// | sgesvd | dgesvd | cgesvd | zgesvd | |
28 | | - /// |
29 | | - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) |
30 | | - -> Result<SVDOutput<Self>>; |
31 | | -} |
32 | | - |
33 | 15 | pub struct SvdWork<T: Scalar> { |
34 | 16 | pub ju: JobSvd, |
35 | 17 | pub jvt: JobSvd, |
@@ -330,109 +312,3 @@ macro_rules! impl_svd_work_r { |
330 | 312 | } |
331 | 313 | impl_svd_work_r!(f64, lapack_sys::dgesvd_); |
332 | 314 | impl_svd_work_r!(f32, lapack_sys::sgesvd_); |
333 | | - |
334 | | -macro_rules! impl_svd { |
335 | | - (@real, $scalar:ty, $gesvd:path) => { |
336 | | - impl_svd!(@body, $scalar, $gesvd, ); |
337 | | - }; |
338 | | - (@complex, $scalar:ty, $gesvd:path) => { |
339 | | - impl_svd!(@body, $scalar, $gesvd, rwork); |
340 | | - }; |
341 | | - (@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => { |
342 | | - impl SVD_ for $scalar { |
343 | | - fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> { |
344 | | - let ju = match l { |
345 | | - MatrixLayout::F { .. } => JobSvd::from_bool(calc_u), |
346 | | - MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt), |
347 | | - }; |
348 | | - let jvt = match l { |
349 | | - MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt), |
350 | | - MatrixLayout::C { .. } => JobSvd::from_bool(calc_u), |
351 | | - }; |
352 | | - |
353 | | - let m = l.lda(); |
354 | | - let mut u = match ju { |
355 | | - JobSvd::All => Some(vec_uninit( (m * m) as usize)), |
356 | | - JobSvd::None => None, |
357 | | - _ => unimplemented!("SVD with partial vector output is not supported yet") |
358 | | - }; |
359 | | - |
360 | | - let n = l.len(); |
361 | | - let mut vt = match jvt { |
362 | | - JobSvd::All => Some(vec_uninit( (n * n) as usize)), |
363 | | - JobSvd::None => None, |
364 | | - _ => unimplemented!("SVD with partial vector output is not supported yet") |
365 | | - }; |
366 | | - |
367 | | - let k = std::cmp::min(m, n); |
368 | | - let mut s = vec_uninit( k as usize); |
369 | | - |
370 | | - $( |
371 | | - let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(5 * k as usize); |
372 | | - )* |
373 | | - |
374 | | - // eval work size |
375 | | - let mut info = 0; |
376 | | - let mut work_size = [Self::zero()]; |
377 | | - unsafe { |
378 | | - $gesvd( |
379 | | - ju.as_ptr(), |
380 | | - jvt.as_ptr(), |
381 | | - &m, |
382 | | - &n, |
383 | | - AsPtr::as_mut_ptr(a), |
384 | | - &m, |
385 | | - AsPtr::as_mut_ptr(&mut s), |
386 | | - AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
387 | | - &m, |
388 | | - AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
389 | | - &n, |
390 | | - AsPtr::as_mut_ptr(&mut work_size), |
391 | | - &(-1), |
392 | | - $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* |
393 | | - &mut info, |
394 | | - ); |
395 | | - } |
396 | | - info.as_lapack_result()?; |
397 | | - |
398 | | - // calc |
399 | | - let lwork = work_size[0].to_usize().unwrap(); |
400 | | - let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork); |
401 | | - unsafe { |
402 | | - $gesvd( |
403 | | - ju.as_ptr(), |
404 | | - jvt.as_ptr() , |
405 | | - &m, |
406 | | - &n, |
407 | | - AsPtr::as_mut_ptr(a), |
408 | | - &m, |
409 | | - AsPtr::as_mut_ptr(&mut s), |
410 | | - AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
411 | | - &m, |
412 | | - AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
413 | | - &n, |
414 | | - AsPtr::as_mut_ptr(&mut work), |
415 | | - &(lwork as i32), |
416 | | - $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* |
417 | | - &mut info, |
418 | | - ); |
419 | | - } |
420 | | - info.as_lapack_result()?; |
421 | | - |
422 | | - let s = unsafe { s.assume_init() }; |
423 | | - let u = u.map(|v| unsafe { v.assume_init() }); |
424 | | - let vt = vt.map(|v| unsafe { v.assume_init() }); |
425 | | - |
426 | | - match l { |
427 | | - MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), |
428 | | - MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), |
429 | | - } |
430 | | - } |
431 | | - } |
432 | | - }; |
433 | | -} // impl_svd! |
434 | | - |
435 | | -impl_svd!(@real, f64, lapack_sys::dgesvd_); |
436 | | -impl_svd!(@real, f32, lapack_sys::sgesvd_); |
437 | | -impl_svd!(@complex, c64, lapack_sys::zgesvd_); |
438 | | -impl_svd!(@complex, c32, lapack_sys::cgesvd_); |
0 commit comments