@@ -180,12 +180,12 @@ macro_rules! binary_func {
180180 ///
181181 /// This is an element wise binary operation.
182182 #[ allow( unused_mut) ]
183- pub fn $fn_name( lhs: & Array , rhs: & Array ) -> Array {
183+ pub fn $fn_name( lhs: & Array , rhs: & Array , batch : bool ) -> Array {
184184 unsafe {
185185 let mut temp: i64 = 0 ;
186186 let err_val = $ffi_fn( & mut temp as MutAfArray ,
187187 lhs. get( ) as AfArray , rhs. get( ) as AfArray ,
188- 0 ) ;
188+ batch as c_int ) ;
189189 HANDLE_ERROR ( AfError :: from( err_val) ) ;
190190 Array :: from( temp)
191191 }
@@ -217,6 +217,8 @@ macro_rules! convertable_type_def {
217217 )
218218}
219219
220+ convertable_type_def ! ( Complex <f64 >) ;
221+ convertable_type_def ! ( Complex <f32 >) ;
220222convertable_type_def ! ( u64 ) ;
221223convertable_type_def ! ( i64 ) ;
222224convertable_type_def ! ( f64 ) ;
@@ -350,45 +352,33 @@ pub fn clamp<T, U> (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array
350352}
351353
352354macro_rules! arith_scalar_func {
353- ( $rust_type: ty, $op_name: ident, $fn_name: ident, $ffi_fn : ident ) => (
355+ ( $rust_type: ty, $op_name: ident, $fn_name: ident) => (
354356 impl <' f> $op_name<$rust_type> for & ' f Array {
355357 type Output = Array ;
356358
357359 fn $fn_name( self , rhs: $rust_type) -> Array {
358- let cnst_arr = constant( rhs, self . dims( ) ) ;
359- unsafe {
360- let mut temp: i64 = 0 ;
361- let err_val = $ffi_fn( & mut temp as MutAfArray , self . get( ) as AfArray ,
362- cnst_arr. get( ) as AfArray , 0 ) ;
363- HANDLE_ERROR ( AfError :: from( err_val) ) ;
364- Array :: from( temp)
365- }
360+ let temp = rhs. clone( ) ;
361+ $fn_name( self , & temp, false )
366362 }
367363 }
368364
369365 impl $op_name<$rust_type> for Array {
370366 type Output = Array ;
371367
372368 fn $fn_name( self , rhs: $rust_type) -> Array {
373- let cnst_arr = constant( rhs, self . dims( ) ) ;
374- unsafe {
375- let mut temp: i64 = 0 ;
376- let err_val = $ffi_fn( & mut temp as MutAfArray , self . get( ) as AfArray ,
377- cnst_arr. get( ) as AfArray , 0 ) ;
378- HANDLE_ERROR ( AfError :: from( err_val) ) ;
379- Array :: from( temp)
380- }
369+ let temp = rhs. clone( ) ;
370+ $fn_name( & self , & temp, false )
381371 }
382372 }
383373 )
384374}
385375
386376macro_rules! arith_scalar_spec {
387377 ( $ty_name: ty) => (
388- arith_scalar_func!( $ty_name, Add , add, af_add ) ;
389- arith_scalar_func!( $ty_name, Sub , sub, af_sub ) ;
390- arith_scalar_func!( $ty_name, Mul , mul, af_mul ) ;
391- arith_scalar_func!( $ty_name, Div , div, af_div ) ;
378+ arith_scalar_func!( $ty_name, Add , add) ;
379+ arith_scalar_func!( $ty_name, Sub , sub) ;
380+ arith_scalar_func!( $ty_name, Mul , mul) ;
381+ arith_scalar_func!( $ty_name, Div , div) ;
392382 )
393383}
394384
@@ -403,33 +393,51 @@ arith_scalar_spec!(i32);
403393arith_scalar_spec ! ( u8 ) ;
404394
405395macro_rules! arith_func {
406- ( $op_name: ident, $fn_name: ident, $ffi_fn : ident) => (
396+ ( $op_name: ident, $fn_name: ident, $delegate : ident) => (
407397 impl $op_name<Array > for Array {
408398 type Output = Array ;
409399
410400 fn $fn_name( self , rhs: Array ) -> Array {
411- unsafe {
412- let mut temp: i64 = 0 ;
413- let err_val = $ffi_fn( & mut temp as MutAfArray ,
414- self . get( ) as AfArray , rhs. get( ) as AfArray , 0 ) ;
415- HANDLE_ERROR ( AfError :: from( err_val) ) ;
416- Array :: from( temp)
417- }
401+ $delegate( & self , & rhs, false )
402+ }
403+ }
404+
405+ impl <' a> $op_name<& ' a Array > for Array {
406+ type Output = Array ;
407+
408+ fn $fn_name( self , rhs: & ' a Array ) -> Array {
409+ $delegate( & self , rhs, false )
410+ }
411+ }
412+
413+ impl <' a> $op_name<Array > for & ' a Array {
414+ type Output = Array ;
415+
416+ fn $fn_name( self , rhs: Array ) -> Array {
417+ $delegate( self , & rhs, false )
418+ }
419+ }
420+
421+ impl <' a, ' b> $op_name<& ' a Array > for & ' b Array {
422+ type Output = Array ;
423+
424+ fn $fn_name( self , rhs: & ' a Array ) -> Array {
425+ $delegate( self , rhs, false )
418426 }
419427 }
420428 )
421429}
422430
423- arith_func ! ( Add , add, af_add ) ;
424- arith_func ! ( Sub , sub, af_sub ) ;
425- arith_func ! ( Mul , mul, af_mul ) ;
426- arith_func ! ( Div , div, af_div ) ;
427- arith_func ! ( Rem , rem, af_rem ) ;
428- arith_func ! ( BitAnd , bitand , af_bitand ) ;
429- arith_func ! ( BitOr , bitor , af_bitor ) ;
430- arith_func ! ( BitXor , bitxor , af_bitxor ) ;
431- arith_func ! ( Shl , shl , af_bitshiftl ) ;
432- arith_func ! ( Shr , shr , af_bitshiftr ) ;
431+ arith_func ! ( Add , add , add ) ;
432+ arith_func ! ( Sub , sub , sub ) ;
433+ arith_func ! ( Mul , mul , mul ) ;
434+ arith_func ! ( Div , div , div ) ;
435+ arith_func ! ( Rem , rem , rem ) ;
436+ arith_func ! ( Shl , shl , shiftl ) ;
437+ arith_func ! ( Shr , shr , shiftr ) ;
438+ arith_func ! ( BitAnd , bitand , bitand ) ;
439+ arith_func ! ( BitOr , bitor , bitor ) ;
440+ arith_func ! ( BitXor , bitxor , bitxor ) ;
433441
434442#[ cfg( op_assign) ]
435443mod op_assign {
@@ -477,7 +485,7 @@ macro_rules! bit_assign_func {
477485 let mut idxrs = Indexer :: new( ) ;
478486 idxrs. set_index( & Seq :: <f32 >:: default ( ) , 0 , Some ( false ) ) ;
479487 idxrs. set_index( & Seq :: <f32 >:: default ( ) , 1 , Some ( false ) ) ;
480- let tmp = assign_gen( self as & Array , & idxrs, & $func( self as & Array , & rhs) ) ;
488+ let tmp = assign_gen( self as & Array , & idxrs, & $func( self as & Array , & rhs, false ) ) ;
481489 mem:: replace( self , tmp) ;
482490 }
483491 }
0 commit comments