Skip to content

Commit d369255

Browse files
committed
Added conditional operator functions in data module
1 parent df5e2cc commit d369255

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

src/data/mod.rs

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ extern {
5959
fn af_flip(out: MutAfArray, arr: AfArray, dim: c_uint) -> c_int;
6060
fn af_lower(out: MutAfArray, arr: AfArray, is_unit_diag: c_int) -> c_int;
6161
fn af_upper(out: MutAfArray, arr: AfArray, is_unit_diag: c_int) -> c_int;
62+
63+
fn af_select(out: MutAfArray, cond: AfArray, a: AfArray, b: AfArray) -> c_int;
64+
fn af_select_scalar_l(out: MutAfArray, cond: AfArray, a: c_double, b: AfArray) -> c_int;
65+
fn af_select_scalar_r(out: MutAfArray, cond: AfArray, a: AfArray, b: c_double) -> c_int;
66+
67+
fn af_replace(a: AfArray, cond: AfArray, b: AfArray) -> c_int;
68+
fn af_replace_scalar(a: AfArray, cond: AfArray, b: c_double) -> c_int;
6269
}
6370

6471
pub trait ConstGenerator {
@@ -555,3 +562,162 @@ pub fn upper(input: &Array, is_unit_diag: bool) -> Result<Array, AfError> {
555562
}
556563
}
557564
}
565+
566+
/// Element wise conditional operator for Arrays
567+
///
568+
/// This function does the C-equivalent of the following statement, except that the operation
569+
/// happens on a GPU for all elements simultaneously.
570+
///
571+
/// ```
572+
/// c = cond ? a : b; /// where cond, a & b are all objects of type Array
573+
/// ```
574+
///
575+
/// # Parameters
576+
///
577+
/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
578+
/// `True`
579+
/// - `cond` is the Array with conditional values
580+
/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
581+
/// `False`
582+
///
583+
/// # Return Values
584+
///
585+
/// An Array
586+
#[allow(unused_mut)]
587+
pub fn select(a: &Array, cond: &Array, b: &Array) -> Result<Array, AfError> {
588+
unsafe {
589+
let mut temp: i64 = 0;
590+
let err_val = af_select(&mut temp as MutAfArray, cond.get() as AfArray,
591+
a.get() as AfArray, b.get() as AfArray);
592+
match err_val {
593+
0 => Ok(Array::from(temp)),
594+
_ => Err(AfError::from(err_val)),
595+
}
596+
}
597+
}
598+
599+
/// Element wise conditional operator for Arrays
600+
///
601+
/// This function does the C-equivalent of the following statement, except that the operation
602+
/// happens on a GPU for all elements simultaneously.
603+
///
604+
/// ```
605+
/// c = cond ? a : b; /// where a is a scalar(f64) and b is Array
606+
/// ```
607+
///
608+
/// # Parameters
609+
///
610+
/// - `a` is the scalar that is assigned to output if corresponding element in `cond` Array is
611+
/// `True`
612+
/// - `cond` is the Array with conditional values
613+
/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
614+
/// `False`
615+
///
616+
/// # Return Values
617+
///
618+
/// An Array
619+
#[allow(unused_mut)]
620+
pub fn selectl(a: f64, cond: &Array, b: &Array) -> Result<Array, AfError> {
621+
unsafe {
622+
let mut temp: i64 = 0;
623+
let err_val = af_select_scalar_l(&mut temp as MutAfArray, cond.get() as AfArray,
624+
a as c_double, b.get() as AfArray);
625+
match err_val {
626+
0 => Ok(Array::from(temp)),
627+
_ => Err(AfError::from(err_val)),
628+
}
629+
}
630+
}
631+
632+
/// Element wise conditional operator for Arrays
633+
///
634+
/// This function does the C-equivalent of the following statement, except that the operation
635+
/// happens on a GPU for all elements simultaneously.
636+
///
637+
/// ```
638+
/// c = cond ? a : b; /// where a is Array and b is a scalar(f64)
639+
/// ```
640+
///
641+
/// # Parameters
642+
///
643+
/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
644+
/// `True`
645+
/// - `cond` is the Array with conditional values
646+
/// - `b` is the scalar that is assigned to output if corresponding element in `cond` Array is
647+
/// `False`
648+
///
649+
/// # Return Values
650+
///
651+
/// An Array
652+
#[allow(unused_mut)]
653+
pub fn selectr(a: &Array, cond: &Array, b: f64) -> Result<Array, AfError> {
654+
unsafe {
655+
let mut temp: i64 = 0;
656+
let err_val = af_select_scalar_r(&mut temp as MutAfArray, cond.get() as AfArray,
657+
a.get() as AfArray, b as c_double);
658+
match err_val {
659+
0 => Ok(Array::from(temp)),
660+
_ => Err(AfError::from(err_val)),
661+
}
662+
}
663+
}
664+
665+
/// Inplace replace in Array based on a condition
666+
///
667+
/// This function does the C-equivalent of the following statement, except that the operation
668+
/// happens on a GPU for all elements simultaneously.
669+
///
670+
/// ```
671+
/// a = cond ? a : b; /// where cond, a & b are all objects of type Array
672+
/// ```
673+
///
674+
/// # Parameters
675+
///
676+
/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
677+
/// - `cond` is the Array with conditional values
678+
/// - `b` is the Array whose element will replace the element in output if corresponding element in `cond` Array is
679+
/// `False`
680+
///
681+
/// # Return Values
682+
///
683+
/// An Array
684+
#[allow(unused_mut)]
685+
pub fn replace(a: &mut Array, cond: &Array, b: &Array) -> Result<(), AfError> {
686+
unsafe {
687+
let err_val = af_replace(a.get() as AfArray, cond.get() as AfArray, b.get() as AfArray);
688+
match err_val {
689+
0 => Ok(()),
690+
_ => Err(AfError::from(err_val)),
691+
}
692+
}
693+
}
694+
695+
/// Inplace replace in Array based on a condition
696+
///
697+
/// This function does the C-equivalent of the following statement, except that the operation
698+
/// happens on a GPU for all elements simultaneously.
699+
///
700+
/// ```
701+
/// a = cond ? a : b; /// where cond, a are Arrays and b is scalar(f64)
702+
/// ```
703+
///
704+
/// # Parameters
705+
///
706+
/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
707+
/// - `cond` is the Array with conditional values
708+
/// - `b` is the scalar that will replace the element in output if corresponding element in `cond` Array is
709+
/// `False`
710+
///
711+
/// # Return Values
712+
///
713+
/// An Array
714+
#[allow(unused_mut)]
715+
pub fn replace_scalar(a: &mut Array, cond: &Array, b: f64) -> Result<(), AfError> {
716+
unsafe {
717+
let err_val = af_replace_scalar(a.get() as AfArray, cond.get() as AfArray, b as c_double);
718+
match err_val {
719+
0 => Ok(()),
720+
_ => Err(AfError::from(err_val)),
721+
}
722+
}
723+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub use data::{set_seed, get_seed, randu, randn};
3131
pub use data::{identity, diag_create, diag_extract, lower, upper};
3232
pub use data::{join, join_many, tile};
3333
pub use data::{reorder, shift, moddims, flat, flip};
34+
pub use data::{select, selectl, selectr, replace, replace_scalar};
3435
mod data;
3536

3637
pub use device::{get_version, info, device_count, is_double_available, set_device, get_device, sync};

0 commit comments

Comments
 (0)