@@ -2,7 +2,8 @@ use ndarray::*;
22use ndarray_linalg:: * ;
33use std:: cmp:: min;
44
5- fn test ( a : & Array2 < f64 > , n : usize , m : usize ) {
5+ fn test ( a : & Array2 < f64 > ) {
6+ let ( n, m) = a. dim ( ) ;
67 let answer = a. clone ( ) ;
78 println ! ( "a = \n {:?}" , a) ;
89 let ( u, s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , true ) . unwrap ( ) ;
@@ -18,38 +19,62 @@ fn test(a: &Array2<f64>, n: usize, m: usize) {
1819 assert_close_l2 ! ( & u. dot( & sm) . dot( & vt) , & answer, 1e-7 ) ;
1920}
2021
21- #[ test]
22- fn svd_square ( ) {
23- let a = random ( ( 3 , 3 ) ) ;
24- test ( & a, 3 , 3 ) ;
22+ fn test_no_vt ( a : & Array2 < f64 > ) {
23+ let ( n, _m) = a. dim ( ) ;
24+ println ! ( "a = \n {:?}" , a) ;
25+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , false ) . unwrap ( ) ;
26+ assert ! ( u. is_some( ) ) ;
27+ assert ! ( vt. is_none( ) ) ;
28+ let u = u. unwrap ( ) ;
29+ assert_eq ! ( u. dim( ) . 0 , n) ;
30+ assert_eq ! ( u. dim( ) . 1 , n) ;
2531}
2632
27- #[ test]
28- fn svd_square_t ( ) {
29- let a = random ( ( 3 , 3 ) . f ( ) ) ;
30- test ( & a, 3 , 3 ) ;
33+ fn test_no_u ( a : & Array2 < f64 > ) {
34+ let ( _n, m) = a. dim ( ) ;
35+ println ! ( "a = \n {:?}" , a) ;
36+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , true ) . unwrap ( ) ;
37+ assert ! ( u. is_none( ) ) ;
38+ assert ! ( vt. is_some( ) ) ;
39+ let vt = vt. unwrap ( ) ;
40+ assert_eq ! ( vt. dim( ) . 0 , m) ;
41+ assert_eq ! ( vt. dim( ) . 1 , m) ;
3142}
3243
33- #[ test]
34- fn svd_3x4 ( ) {
35- let a = random ( ( 3 , 4 ) ) ;
36- test ( & a, 3 , 4 ) ;
44+ fn test_diag_only ( a : & Array2 < f64 > ) {
45+ println ! ( "a = \n {:?}" , a) ;
46+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , false ) . unwrap ( ) ;
47+ assert ! ( u. is_none( ) ) ;
48+ assert ! ( vt. is_none( ) ) ;
3749}
3850
39- #[ test]
40- fn svd_3x4_t ( ) {
41- let a = random ( ( 3 , 4 ) . f ( ) ) ;
42- test ( & a, 3 , 4 ) ;
43- }
51+ macro_rules! test_svd_impl {
52+ ( $test: ident, $n: expr, $m: expr) => {
53+ paste:: item! {
54+ #[ test]
55+ fn [ <svd_ $test _ $n x $m>] ( ) {
56+ let a = random( ( $n, $m) ) ;
57+ $test( & a) ;
58+ }
4459
45- #[ test]
46- fn svd_4x3 ( ) {
47- let a = random ( ( 4 , 3 ) ) ;
48- test ( & a, 4 , 3 ) ;
60+ #[ test]
61+ fn [ <svd_ $test _ $n x $m _t>] ( ) {
62+ let a = random( ( $n, $m) . f( ) ) ;
63+ $test( & a) ;
64+ }
65+ }
66+ } ;
4967}
5068
51- #[ test]
52- fn svd_4x3_t ( ) {
53- let a = random ( ( 4 , 3 ) . f ( ) ) ;
54- test ( & a, 4 , 3 ) ;
55- }
69+ test_svd_impl ! ( test, 3 , 3 ) ;
70+ test_svd_impl ! ( test_no_vt, 3 , 3 ) ;
71+ test_svd_impl ! ( test_no_u, 3 , 3 ) ;
72+ test_svd_impl ! ( test_diag_only, 3 , 3 ) ;
73+ test_svd_impl ! ( test, 4 , 3 ) ;
74+ test_svd_impl ! ( test_no_vt, 4 , 3 ) ;
75+ test_svd_impl ! ( test_no_u, 4 , 3 ) ;
76+ test_svd_impl ! ( test_diag_only, 4 , 3 ) ;
77+ test_svd_impl ! ( test, 3 , 4 ) ;
78+ test_svd_impl ! ( test_no_vt, 3 , 4 ) ;
79+ test_svd_impl ! ( test_no_u, 3 , 4 ) ;
80+ test_svd_impl ! ( test_diag_only, 3 , 4 ) ;
0 commit comments