Skip to content

Commit 551d054

Browse files
committed
FEATURE: clamp function in arith module
1 parent e63eb0c commit 551d054

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/arith/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ extern {
4242
fn af_bitshiftr(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
4343
fn af_minof(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
4444
fn af_maxof(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
45+
fn af_clamp(out: MutAfArray, inp: AfArray, lo: AfArray, hi: AfArray, batch: c_int) -> c_int;
4546

4647
fn af_not(out: MutAfArray, arr: AfArray) -> c_int;
4748
fn af_abs(out: MutAfArray, arr: AfArray) -> c_int;
@@ -293,6 +294,35 @@ overloaded_binary_func!("Create complex array from two Arrays", cplx2, cplx2_hel
293294
overloaded_binary_func!("Compute root", root, root_helper, af_root);
294295
overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);
295296

297+
pub fn clamp<T, U> (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array
298+
where T: Convertable, U: Convertable
299+
{
300+
let clamp_helper = |lo: &Array, hi: &Array| {
301+
unsafe {
302+
let mut temp: i64 = 0;
303+
let err_val = af_clamp(&mut temp as MutAfArray, input.get() as AfArray,
304+
lo.get() as AfArray, hi.get() as AfArray,
305+
batch as c_int);
306+
HANDLE_ERROR(AfError::from(err_val));
307+
Array::from(temp)
308+
}
309+
};
310+
311+
let lo = arg1.convert();
312+
let hi = arg2.convert();
313+
match (lo.is_scalar(), hi.is_scalar()) {
314+
( true, false) => {
315+
let l = tile(&lo, hi.dims());
316+
clamp_helper(&l, &hi)
317+
},
318+
(false, true) => {
319+
let r = tile(&hi, lo.dims());
320+
clamp_helper(&lo, &r)
321+
},
322+
_ => clamp_helper(&lo, &hi),
323+
}
324+
}
325+
296326
macro_rules! arith_scalar_func {
297327
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (
298328
impl<'f> $op_name<$rust_type> for &'f Array {

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ mod algorithm;
1919

2020
pub use arith::{add, sub, div, mul, lt, gt, le, ge, eq, neq, and, or, minof, maxof, rem};
2121
pub use arith::{bitand, bitor, bitxor, shiftl, shiftr};
22-
pub use arith::{abs, sign, round, trunc, floor, ceil, modulo, sigmoid};
22+
pub use arith::{abs, sign, round, trunc, floor, ceil, modulo, sigmoid, clamp};
2323
pub use arith::{sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh};
2424
pub use arith::{atan2, cplx2, arg, cplx, real, imag, conjg, hypot};
2525
pub use arith::{sqrt, log, log1p, log10, log2, pow2, exp, expm1, erf, erfc, root, pow};

0 commit comments

Comments
 (0)