@@ -7,6 +7,8 @@ use ndarray_parallel::prelude::*;
77
88const M : usize = 1024 * 10 ;
99const N : usize = 100 ;
10+ const CHUNK_SIZE : usize = 100 ;
11+ const N_CHUNKS : usize = ( M + CHUNK_SIZE - 1 ) / CHUNK_SIZE ;
1012
1113#[ test]
1214fn test_axis_iter ( ) {
@@ -53,3 +55,32 @@ fn test_regular_iter_collect() {
5355 let v = a. view ( ) . into_par_iter ( ) . map ( |& x| x) . collect :: < Vec < _ > > ( ) ;
5456 assert_eq ! ( v. len( ) , a. len( ) ) ;
5557}
58+
59+ #[ test]
60+ fn test_axis_chunks_iter ( ) {
61+ let mut a = Array2 :: < f64 > :: zeros ( ( M , N ) ) ;
62+ for ( i, mut v) in a. axis_chunks_iter_mut ( Axis ( 0 ) , CHUNK_SIZE ) . enumerate ( ) {
63+ v. fill ( i as _ ) ;
64+ }
65+ assert_eq ! ( a. axis_chunks_iter( Axis ( 0 ) , CHUNK_SIZE ) . len( ) , N_CHUNKS ) ;
66+ let s: f64 = a
67+ . axis_chunks_iter ( Axis ( 0 ) , CHUNK_SIZE )
68+ . into_par_iter ( )
69+ . map ( |x| x. sum ( ) )
70+ . sum ( ) ;
71+ println ! ( "{:?}" , a. slice( s![ ..10 , ..5 ] ) ) ;
72+ assert_eq ! ( s, a. sum( ) ) ;
73+ }
74+
75+ #[ test]
76+ fn test_axis_chunks_iter_mut ( ) {
77+ let mut a = Array :: linspace ( 0. , 1.0f64 , M * N )
78+ . into_shape ( ( M , N ) )
79+ . unwrap ( ) ;
80+ let b = a. mapv ( |x| x. exp ( ) ) ;
81+ a. axis_chunks_iter_mut ( Axis ( 0 ) , CHUNK_SIZE )
82+ . into_par_iter ( )
83+ . for_each ( |mut v| v. mapv_inplace ( |x| x. exp ( ) ) ) ;
84+ println ! ( "{:?}" , a. slice( s![ ..10 , ..5 ] ) ) ;
85+ assert ! ( a. all_close( & b, 0.001 ) ) ;
86+ }
0 commit comments