@@ -59,6 +59,13 @@ extern {
5959 fn af_flip ( out : MutAfArray , arr : AfArray , dim : c_uint ) -> c_int ;
6060 fn af_lower ( out : MutAfArray , arr : AfArray , is_unit_diag : c_int ) -> c_int ;
6161 fn af_upper ( out : MutAfArray , arr : AfArray , is_unit_diag : c_int ) -> c_int ;
62+
63+ fn af_select ( out : MutAfArray , cond : AfArray , a : AfArray , b : AfArray ) -> c_int ;
64+ fn af_select_scalar_l ( out : MutAfArray , cond : AfArray , a : c_double , b : AfArray ) -> c_int ;
65+ fn af_select_scalar_r ( out : MutAfArray , cond : AfArray , a : AfArray , b : c_double ) -> c_int ;
66+
67+ fn af_replace ( a : AfArray , cond : AfArray , b : AfArray ) -> c_int ;
68+ fn af_replace_scalar ( a : AfArray , cond : AfArray , b : c_double ) -> c_int ;
6269}
6370
6471pub trait ConstGenerator {
@@ -555,3 +562,162 @@ pub fn upper(input: &Array, is_unit_diag: bool) -> Result<Array, AfError> {
555562 }
556563 }
557564}
565+
566+ /// Element wise conditional operator for Arrays
567+ ///
568+ /// This function does the C-equivalent of the following statement, except that the operation
569+ /// happens on a GPU for all elements simultaneously.
570+ ///
571+ /// ```
572+ /// c = cond ? a : b; /// where cond, a & b are all objects of type Array
573+ /// ```
574+ ///
575+ /// # Parameters
576+ ///
577+ /// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
578+ /// `True`
579+ /// - `cond` is the Array with conditional values
580+ /// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
581+ /// `False`
582+ ///
583+ /// # Return Values
584+ ///
585+ /// An Array
586+ #[ allow( unused_mut) ]
587+ pub fn select ( a : & Array , cond : & Array , b : & Array ) -> Result < Array , AfError > {
588+ unsafe {
589+ let mut temp: i64 = 0 ;
590+ let err_val = af_select ( & mut temp as MutAfArray , cond. get ( ) as AfArray ,
591+ a. get ( ) as AfArray , b. get ( ) as AfArray ) ;
592+ match err_val {
593+ 0 => Ok ( Array :: from ( temp) ) ,
594+ _ => Err ( AfError :: from ( err_val) ) ,
595+ }
596+ }
597+ }
598+
599+ /// Element wise conditional operator for Arrays
600+ ///
601+ /// This function does the C-equivalent of the following statement, except that the operation
602+ /// happens on a GPU for all elements simultaneously.
603+ ///
604+ /// ```
605+ /// c = cond ? a : b; /// where a is a scalar(f64) and b is Array
606+ /// ```
607+ ///
608+ /// # Parameters
609+ ///
610+ /// - `a` is the scalar that is assigned to output if corresponding element in `cond` Array is
611+ /// `True`
612+ /// - `cond` is the Array with conditional values
613+ /// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
614+ /// `False`
615+ ///
616+ /// # Return Values
617+ ///
618+ /// An Array
619+ #[ allow( unused_mut) ]
620+ pub fn selectl ( a : f64 , cond : & Array , b : & Array ) -> Result < Array , AfError > {
621+ unsafe {
622+ let mut temp: i64 = 0 ;
623+ let err_val = af_select_scalar_l ( & mut temp as MutAfArray , cond. get ( ) as AfArray ,
624+ a as c_double , b. get ( ) as AfArray ) ;
625+ match err_val {
626+ 0 => Ok ( Array :: from ( temp) ) ,
627+ _ => Err ( AfError :: from ( err_val) ) ,
628+ }
629+ }
630+ }
631+
632+ /// Element wise conditional operator for Arrays
633+ ///
634+ /// This function does the C-equivalent of the following statement, except that the operation
635+ /// happens on a GPU for all elements simultaneously.
636+ ///
637+ /// ```
638+ /// c = cond ? a : b; /// where a is Array and b is a scalar(f64)
639+ /// ```
640+ ///
641+ /// # Parameters
642+ ///
643+ /// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
644+ /// `True`
645+ /// - `cond` is the Array with conditional values
646+ /// - `b` is the scalar that is assigned to output if corresponding element in `cond` Array is
647+ /// `False`
648+ ///
649+ /// # Return Values
650+ ///
651+ /// An Array
652+ #[ allow( unused_mut) ]
653+ pub fn selectr ( a : & Array , cond : & Array , b : f64 ) -> Result < Array , AfError > {
654+ unsafe {
655+ let mut temp: i64 = 0 ;
656+ let err_val = af_select_scalar_r ( & mut temp as MutAfArray , cond. get ( ) as AfArray ,
657+ a. get ( ) as AfArray , b as c_double ) ;
658+ match err_val {
659+ 0 => Ok ( Array :: from ( temp) ) ,
660+ _ => Err ( AfError :: from ( err_val) ) ,
661+ }
662+ }
663+ }
664+
665+ /// Inplace replace in Array based on a condition
666+ ///
667+ /// This function does the C-equivalent of the following statement, except that the operation
668+ /// happens on a GPU for all elements simultaneously.
669+ ///
670+ /// ```
671+ /// a = cond ? a : b; /// where cond, a & b are all objects of type Array
672+ /// ```
673+ ///
674+ /// # Parameters
675+ ///
676+ /// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
677+ /// - `cond` is the Array with conditional values
678+ /// - `b` is the Array whose element will replace the element in output if corresponding element in `cond` Array is
679+ /// `False`
680+ ///
681+ /// # Return Values
682+ ///
683+ /// An Array
684+ #[ allow( unused_mut) ]
685+ pub fn replace ( a : & mut Array , cond : & Array , b : & Array ) -> Result < ( ) , AfError > {
686+ unsafe {
687+ let err_val = af_replace ( a. get ( ) as AfArray , cond. get ( ) as AfArray , b. get ( ) as AfArray ) ;
688+ match err_val {
689+ 0 => Ok ( ( ) ) ,
690+ _ => Err ( AfError :: from ( err_val) ) ,
691+ }
692+ }
693+ }
694+
695+ /// Inplace replace in Array based on a condition
696+ ///
697+ /// This function does the C-equivalent of the following statement, except that the operation
698+ /// happens on a GPU for all elements simultaneously.
699+ ///
700+ /// ```
701+ /// a = cond ? a : b; /// where cond, a are Arrays and b is scalar(f64)
702+ /// ```
703+ ///
704+ /// # Parameters
705+ ///
706+ /// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
707+ /// - `cond` is the Array with conditional values
708+ /// - `b` is the scalar that will replace the element in output if corresponding element in `cond` Array is
709+ /// `False`
710+ ///
711+ /// # Return Values
712+ ///
713+ /// An Array
714+ #[ allow( unused_mut) ]
715+ pub fn replace_scalar ( a : & mut Array , cond : & Array , b : f64 ) -> Result < ( ) , AfError > {
716+ unsafe {
717+ let err_val = af_replace_scalar ( a. get ( ) as AfArray , cond. get ( ) as AfArray , b as c_double ) ;
718+ match err_val {
719+ 0 => Ok ( ( ) ) ,
720+ _ => Err ( AfError :: from ( err_val) ) ,
721+ }
722+ }
723+ }
0 commit comments