|
| 1 | +//! Compute singular value decomposition with divide-and-conquer algorithm |
| 2 | +//! |
| 3 | +//! LAPACK correspondance |
| 4 | +//! ---------------------- |
| 5 | +//! |
| 6 | +//! | f32 | f64 | c32 | c64 | |
| 7 | +//! |:-------|:-------|:-------|:-------| |
| 8 | +//! | sgesdd | dgesdd | cgesdd | zgesdd | |
| 9 | +//! |
| 10 | +
|
1 | 11 | use crate::{error::*, layout::MatrixLayout, *}; |
2 | 12 | use cauchy::*; |
3 | 13 | use num_traits::{ToPrimitive, Zero}; |
4 | 14 |
|
5 | | -#[cfg_attr(doc, katexit::katexit)] |
6 | | -/// Singular value decomposition with divide-and-conquer method |
7 | | -pub trait SVDDC_: Scalar { |
8 | | - /// Compute singular value decomposition $A = U \Sigma V^T$ |
9 | | - /// |
10 | | - /// LAPACK correspondance |
11 | | - /// ---------------------- |
12 | | - /// |
13 | | - /// | f32 | f64 | c32 | c64 | |
14 | | - /// |:-------|:-------|:-------|:-------| |
15 | | - /// | sgesdd | dgesdd | cgesdd | zgesdd | |
16 | | - /// |
17 | | - fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>; |
18 | | -} |
19 | | - |
20 | 15 | pub struct SvdDcWork<T: Scalar> { |
21 | 16 | pub jobz: JobSvd, |
22 | 17 | pub layout: MatrixLayout, |
@@ -310,111 +305,3 @@ macro_rules! impl_svd_dc_work_r { |
310 | 305 | } |
311 | 306 | impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_); |
312 | 307 | impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_); |
313 | | - |
314 | | -macro_rules! impl_svddc { |
315 | | - (@real, $scalar:ty, $gesdd:path) => { |
316 | | - impl_svddc!(@body, $scalar, $gesdd, ); |
317 | | - }; |
318 | | - (@complex, $scalar:ty, $gesdd:path) => { |
319 | | - impl_svddc!(@body, $scalar, $gesdd, rwork); |
320 | | - }; |
321 | | - (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { |
322 | | - impl SVDDC_ for $scalar { |
323 | | - fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result<SvdOwned<Self>> { |
324 | | - let m = l.lda(); |
325 | | - let n = l.len(); |
326 | | - let k = m.min(n); |
327 | | - let mut s = vec_uninit(k as usize); |
328 | | - |
329 | | - let (u_col, vt_row) = match jobz { |
330 | | - JobSvd::All | JobSvd::None => (m, n), |
331 | | - JobSvd::Some => (k, k), |
332 | | - }; |
333 | | - let (mut u, mut vt) = match jobz { |
334 | | - JobSvd::All => ( |
335 | | - Some(vec_uninit((m * m) as usize)), |
336 | | - Some(vec_uninit((n * n) as usize)), |
337 | | - ), |
338 | | - JobSvd::Some => ( |
339 | | - Some(vec_uninit((m * u_col) as usize)), |
340 | | - Some(vec_uninit((n * vt_row) as usize)), |
341 | | - ), |
342 | | - JobSvd::None => (None, None), |
343 | | - }; |
344 | | - |
345 | | - $( // for complex only |
346 | | - let mx = n.max(m) as usize; |
347 | | - let mn = n.min(m) as usize; |
348 | | - let lrwork = match jobz { |
349 | | - JobSvd::None => 7 * mn, |
350 | | - _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), |
351 | | - }; |
352 | | - let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(lrwork); |
353 | | - )* |
354 | | - |
355 | | - // eval work size |
356 | | - let mut info = 0; |
357 | | - let mut iwork: Vec<MaybeUninit<i32>> = vec_uninit(8 * k as usize); |
358 | | - let mut work_size = [Self::zero()]; |
359 | | - unsafe { |
360 | | - $gesdd( |
361 | | - jobz.as_ptr(), |
362 | | - &m, |
363 | | - &n, |
364 | | - AsPtr::as_mut_ptr(a), |
365 | | - &m, |
366 | | - AsPtr::as_mut_ptr(&mut s), |
367 | | - AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
368 | | - &m, |
369 | | - AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
370 | | - &vt_row, |
371 | | - AsPtr::as_mut_ptr(&mut work_size), |
372 | | - &(-1), |
373 | | - $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* |
374 | | - AsPtr::as_mut_ptr(&mut iwork), |
375 | | - &mut info, |
376 | | - ); |
377 | | - } |
378 | | - info.as_lapack_result()?; |
379 | | - |
380 | | - // do svd |
381 | | - let lwork = work_size[0].to_usize().unwrap(); |
382 | | - let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork); |
383 | | - unsafe { |
384 | | - $gesdd( |
385 | | - jobz.as_ptr(), |
386 | | - &m, |
387 | | - &n, |
388 | | - AsPtr::as_mut_ptr(a), |
389 | | - &m, |
390 | | - AsPtr::as_mut_ptr(&mut s), |
391 | | - AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
392 | | - &m, |
393 | | - AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])), |
394 | | - &vt_row, |
395 | | - AsPtr::as_mut_ptr(&mut work), |
396 | | - &(lwork as i32), |
397 | | - $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* |
398 | | - AsPtr::as_mut_ptr(&mut iwork), |
399 | | - &mut info, |
400 | | - ); |
401 | | - } |
402 | | - info.as_lapack_result()?; |
403 | | - |
404 | | - let s = unsafe { s.assume_init() }; |
405 | | - let u = u.map(|v| unsafe { v.assume_init() }); |
406 | | - let vt = vt.map(|v| unsafe { v.assume_init() }); |
407 | | - |
408 | | - match l { |
409 | | - MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }), |
410 | | - MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }), |
411 | | - } |
412 | | - } |
413 | | - } |
414 | | - }; |
415 | | -} |
416 | | - |
417 | | -impl_svddc!(@real, f32, lapack_sys::sgesdd_); |
418 | | -impl_svddc!(@real, f64, lapack_sys::dgesdd_); |
419 | | -impl_svddc!(@complex, c32, lapack_sys::cgesdd_); |
420 | | -impl_svddc!(@complex, c64, lapack_sys::zgesdd_); |
0 commit comments