@@ -2,13 +2,8 @@ use ndarray::*;
22use ndarray_linalg:: * ;
33use std:: cmp:: min;
44
5- fn test ( a : & Array2 < f64 > , n : usize , m : usize ) {
6- test_both ( a, n, m) ;
7- test_u ( a, n, m) ;
8- test_vt ( a, n, m) ;
9- }
10-
11- fn test_both ( a : & Array2 < f64 > , n : usize , m : usize ) {
5+ fn test ( a : & Array2 < f64 > ) {
6+ let ( n, m) = a. dim ( ) ;
127 let answer = a. clone ( ) ;
138 println ! ( "a = \n {:?}" , a) ;
149 let ( u, s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , true ) . unwrap ( ) ;
@@ -24,7 +19,8 @@ fn test_both(a: &Array2<f64>, n: usize, m: usize) {
2419 assert_close_l2 ! ( & u. dot( & sm) . dot( & vt) , & answer, 1e-7 ) ;
2520}
2621
27- fn test_u ( a : & Array2 < f64 > , n : usize , _m : usize ) {
22+ fn test_u ( a : & Array2 < f64 > ) {
23+ let ( n, _m) = a. dim ( ) ;
2824 println ! ( "a = \n {:?}" , a) ;
2925 let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , false ) . unwrap ( ) ;
3026 assert ! ( u. is_some( ) ) ;
@@ -34,7 +30,8 @@ fn test_u(a: &Array2<f64>, n: usize, _m: usize) {
3430 assert_eq ! ( u. dim( ) . 1 , n) ;
3531}
3632
37- fn test_vt ( a : & Array2 < f64 > , _n : usize , m : usize ) {
33+ fn test_vt ( a : & Array2 < f64 > ) {
34+ let ( _n, m) = a. dim ( ) ;
3835 println ! ( "a = \n {:?}" , a) ;
3936 let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , true ) . unwrap ( ) ;
4037 assert ! ( u. is_none( ) ) ;
@@ -44,38 +41,30 @@ fn test_vt(a: &Array2<f64>, _n: usize, m: usize) {
4441 assert_eq ! ( vt. dim( ) . 1 , m) ;
4542}
4643
47- #[ test]
48- fn svd_square ( ) {
49- let a = random ( ( 3 , 3 ) ) ;
50- test ( & a, 3 , 3 ) ;
51- }
52-
53- #[ test]
54- fn svd_square_t ( ) {
55- let a = random ( ( 3 , 3 ) . f ( ) ) ;
56- test ( & a, 3 , 3 ) ;
57- }
58-
59- #[ test]
60- fn svd_3x4 ( ) {
61- let a = random ( ( 3 , 4 ) ) ;
62- test ( & a, 3 , 4 ) ;
63- }
44+ macro_rules! test_svd_impl {
45+ ( $test: ident, $n: expr, $m: expr) => {
46+ paste:: item! {
47+ #[ test]
48+ fn [ <svd_ $test _ $n x $m>] ( ) {
49+ let a = random( ( $n, $m) ) ;
50+ $test( & a) ;
51+ }
6452
65- #[ test]
66- fn svd_3x4_t ( ) {
67- let a = random ( ( 3 , 4 ) . f ( ) ) ;
68- test ( & a, 3 , 4 ) ;
53+ #[ test]
54+ fn [ <svd_ $test _ $n x $m _t>] ( ) {
55+ let a = random( ( $n, $m) . f( ) ) ;
56+ $test( & a) ;
57+ }
58+ }
59+ } ;
6960}
7061
71- #[ test]
72- fn svd_4x3 ( ) {
73- let a = random ( ( 4 , 3 ) ) ;
74- test ( & a, 4 , 3 ) ;
75- }
76-
77- #[ test]
78- fn svd_4x3_t ( ) {
79- let a = random ( ( 4 , 3 ) . f ( ) ) ;
80- test ( & a, 4 , 3 ) ;
81- }
62+ test_svd_impl ! ( test, 3 , 3 ) ;
63+ test_svd_impl ! ( test_u, 3 , 3 ) ;
64+ test_svd_impl ! ( test_vt, 3 , 3 ) ;
65+ test_svd_impl ! ( test, 4 , 3 ) ;
66+ test_svd_impl ! ( test_u, 4 , 3 ) ;
67+ test_svd_impl ! ( test_vt, 4 , 3 ) ;
68+ test_svd_impl ! ( test, 3 , 4 ) ;
69+ test_svd_impl ! ( test_u, 3 , 4 ) ;
70+ test_svd_impl ! ( test_vt, 3 , 4 ) ;
0 commit comments