Skip to content

Commit 0a64db6

Browse files
committed
Merge pull request #20 from 9prady9/function_overloads
Function overloads for binary operations
2 parents 21fc3e6 + 9232326 commit 0a64db6

File tree

6 files changed

+116
-30
lines changed

6 files changed

+116
-30
lines changed

build.conf

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
2-
"use_backend": "cuda",
2+
"use_backend": "cpu",
33

44
"use_lib": false,
55
"lib_dir": "/usr/local/lib",
66
"inc_dir": "/usr/local/include",
77

88
"build_type": "Release",
99
"build_threads": "4",
10-
"build_cuda": "ON",
10+
"build_cuda": "OFF",
1111
"build_opencl": "ON",
1212
"build_cpu": "ON",
1313
"build_examples": "OFF",
@@ -28,7 +28,7 @@
2828
"glew_dir": "E:\\Libraries\\GLEW",
2929
"glfw_dir": "E:\\Libraries\\glfw3",
3030
"boost_dir": "E:\\Libraries\\boost_1_56_0",
31-
31+
3232
"cuda_sdk": "/usr/local/cuda",
3333
"opencl_sdk": "/usr",
3434
"sdk_lib_dir": "lib"

examples/helloworld.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ extern crate arrayfire as af;
33
use af::Dim4;
44
use af::Array;
55

6+
#[allow(unused_must_use)]
67
fn main() {
78
af::set_device(0);
89
af::info();
@@ -14,10 +15,9 @@ fn main() {
1415
af::print(&a);
1516

1617
println!("Element-wise arithmetic");
17-
let sin_res = af::sin(&a).unwrap();
18-
let cos_res = af::cos(&a).unwrap();
19-
let b = &sin_res + 1.5;
20-
let b2 = &sin_res + &cos_res;
18+
let b = af::add(af::sin(&a), 1.5).unwrap();
19+
let b2 = af::add(af::sin(&a), af::cos(&a)).unwrap();
20+
2121
let b3 = ! &a;
2222
println!("sin(a) + 1.5 => "); af::print(&b);
2323
println!("sin(a) + cos(a) => "); af::print(&b2);

src/arith/mod.rs

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
extern crate libc;
22
extern crate num;
33

4+
use dim4::Dim4;
45
use array::Array;
56
use defines::AfError;
67
use self::libc::{c_int};
7-
use data::constant;
8+
use data::{constant, tile};
89
use self::num::Complex;
910

1011
type MutAfArray = *mut self::libc::c_longlong;
@@ -182,32 +183,100 @@ macro_rules! binary_func {
182183
)
183184
}
184185

185-
binary_func!(add, af_add);
186-
binary_func!(sub, af_sub);
187-
binary_func!(mul, af_mul);
188-
binary_func!(div, af_div);
189-
binary_func!(rem, af_rem);
190186
binary_func!(bitand, af_bitand);
191187
binary_func!(bitor, af_bitor);
192188
binary_func!(bitxor, af_bitxor);
193-
binary_func!(shiftl, af_bitshiftl);
194-
binary_func!(shiftr, af_bitshiftr);
195-
binary_func!(lt, af_lt);
196-
binary_func!(gt, af_gt);
197-
binary_func!(le, af_le);
198-
binary_func!(ge, af_ge);
199-
binary_func!(eq, af_eq);
200189
binary_func!(neq, af_neq);
201190
binary_func!(and, af_and);
202191
binary_func!(or, af_or);
203192
binary_func!(minof, af_minof);
204193
binary_func!(maxof, af_maxof);
205-
binary_func!(modulo, af_mod);
206194
binary_func!(hypot, af_hypot);
207-
binary_func!(atan2, af_atan2);
208-
binary_func!(cplx2, af_cplx2);
209-
binary_func!(root, af_root);
210-
binary_func!(pow, af_pow);
195+
196+
pub trait Convertable {
197+
fn convert(&self) -> Array;
198+
}
199+
200+
macro_rules! convertable_type_def {
201+
($rust_type: ty) => (
202+
impl Convertable for $rust_type {
203+
fn convert(&self) -> Array {
204+
constant(*self, Dim4::new(&[1,1,1,1])).unwrap()
205+
}
206+
}
207+
)
208+
}
209+
210+
convertable_type_def!(f64);
211+
convertable_type_def!(f32);
212+
convertable_type_def!(i32);
213+
convertable_type_def!(u32);
214+
convertable_type_def!(u8);
215+
216+
impl Convertable for Array {
217+
fn convert(&self) -> Array {
218+
self.clone()
219+
}
220+
}
221+
222+
impl Convertable for Result<Array, AfError> {
223+
fn convert(&self) -> Array {
224+
self.clone().unwrap()
225+
}
226+
}
227+
228+
macro_rules! overloaded_binary_func {
229+
($fn_name: ident, $help_name: ident, $ffi_name: ident) => (
230+
fn $help_name(lhs: &Array, rhs: &Array) -> Result<Array, AfError> {
231+
unsafe {
232+
let mut temp: i64 = 0;
233+
let err_val = $ffi_name(&mut temp as MutAfArray,
234+
lhs.get() as AfArray, rhs.get() as AfArray,
235+
0);
236+
match err_val {
237+
0 => Ok(Array::from(temp)),
238+
_ => Err(AfError::from(err_val)),
239+
}
240+
}
241+
}
242+
243+
pub fn $fn_name<T: Convertable, U: Convertable> (arg1: T, arg2: U) -> Result<Array, AfError> {
244+
let lhs = arg1.convert();
245+
let rhs = arg2.convert();
246+
match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) {
247+
( true, false) => {
248+
let l = tile(&lhs, rhs.dims().unwrap()).unwrap();
249+
$help_name(&l, &rhs)
250+
},
251+
(false, true) => {
252+
let r = tile(&rhs, lhs.dims().unwrap()).unwrap();
253+
$help_name(&lhs, &r)
254+
},
255+
_ => $help_name(&lhs, &rhs),
256+
}
257+
}
258+
)
259+
}
260+
261+
// thanks to Umar Arshad for the idea on how to
262+
// implement overloaded function
263+
overloaded_binary_func!(add, add_helper, af_add);
264+
overloaded_binary_func!(sub, sub_helper, af_sub);
265+
overloaded_binary_func!(mul, mul_helper, af_mul);
266+
overloaded_binary_func!(div, div_helper, af_div);
267+
overloaded_binary_func!(rem, rem_helper, af_rem);
268+
overloaded_binary_func!(shiftl, shiftl_helper, af_bitshiftl);
269+
overloaded_binary_func!(shiftr, shiftr_helper, af_bitshiftr);
270+
overloaded_binary_func!(lt, lt_helper, af_lt);
271+
overloaded_binary_func!(gt, gt_helper, af_gt);
272+
overloaded_binary_func!(le, le_helper, af_le);
273+
overloaded_binary_func!(ge, ge_helper, af_ge);
274+
overloaded_binary_func!(eq, eq_helper, af_eq);
275+
overloaded_binary_func!(modulo, modulo_helper, af_mod);
276+
overloaded_binary_func!(atan2, atan2_helper, af_atan2);
277+
overloaded_binary_func!(cplx2, cplx2_helper, af_cplx2);
278+
overloaded_binary_func!(root, root_helper, af_root);
279+
overloaded_binary_func!(pow, pow_helper, af_pow);
211280

212281
macro_rules! arith_scalar_func {
213282
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (

src/array.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ extern {
5757

5858
fn af_retain_array(out: MutAfArray, arr: AfArray) -> c_int;
5959

60+
fn af_copy_array(out: MutAfArray, arr: AfArray) -> c_int;
61+
6062
fn af_release_array(arr: AfArray) -> c_int;
6163

6264
fn af_print_array(arr: AfArray) -> c_int;
@@ -171,6 +173,17 @@ impl Array {
171173
}
172174
}
173175

176+
pub fn copy(&self) -> Result<Array, AfError> {
177+
unsafe {
178+
let mut temp: i64 = 0;
179+
let err_val = af_copy_array(&mut temp as MutAfArray, self.handle as AfArray);
180+
match err_val {
181+
0 => Ok(Array::from(temp)),
182+
_ => Err(AfError::from(err_val)),
183+
}
184+
}
185+
}
186+
174187
is_func!(is_empty, af_is_empty);
175188
is_func!(is_scalar, af_is_scalar);
176189
is_func!(is_row, af_is_row);

src/data/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,16 @@ impl ConstGenerator for Complex<f64> {
133133

134134
#[allow(unused_mut)]
135135
impl ConstGenerator for bool {
136-
fn generate(&self, dims: Dim4) -> Array {
136+
fn generate(&self, dims: Dim4) -> Result<Array, AfError> {
137137
unsafe {
138138
let mut temp: i64 = 0;
139-
af_constant(&mut temp as MutAfArray, *self as c_int as c_double,
140-
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT, 4);
141-
Array::from(temp)
139+
let err_val = af_constant(&mut temp as MutAfArray, *self as c_int as c_double,
140+
dims.ndims() as c_uint,
141+
dims.get().as_ptr() as *const DimT, 4);
142+
match err_val {
143+
0 => Ok(Array::from(temp)),
144+
_ => Err(AfError::from(err_val)),
145+
}
142146
}
143147
}
144148
}

src/dim4.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl Dim4 {
3939
let nelems = self.elements();
4040
match nelems {
4141
0 => 0,
42-
1 => 0,
42+
1 => 1,
4343
_ => {
4444
if self.dims[3] != 1 { 4 }
4545
else if self.dims[2] != 1 { 3 }

0 commit comments

Comments
 (0)