Skip to content

Commit 3a64cac

Browse files
committed
Adding support for replacing nan values for reductions
1 parent 4c3a2a6 commit 3a64cac

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

arrayfire/algorithm.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,29 @@ def _parallel_dim(a, dim, c_func):
2222
def _reduce_all(a, c_func):
2323
real = ct.c_double(0)
2424
imag = ct.c_double(0)
25+
2526
safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr))
27+
2628
real = real.value
2729
imag = imag.value
2830
return real if imag == 0 else real + imag * 1j
2931

30-
def sum(a, dim=None):
32+
def _nan_parallel_dim(a, dim, c_func, nan_val):
33+
out = Array()
34+
safe_call(c_func(ct.pointer(out.arr), a.arr, ct.c_int(dim), ct.c_double(nan_val)))
35+
return out
36+
37+
def _nan_reduce_all(a, c_func, nan_val):
38+
real = ct.c_double(0)
39+
imag = ct.c_double(0)
40+
41+
safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr, ct.c_double(nan_val)))
42+
43+
real = real.value
44+
imag = imag.value
45+
return real if imag == 0 else real + imag * 1j
46+
47+
def sum(a, dim=None, nan_val=None):
3148
"""
3249
Calculate the sum of all the elements along a specified dimension.
3350
@@ -37,19 +54,27 @@ def sum(a, dim=None):
3754
Multi dimensional arrayfire array.
3855
dim: optional: int. default: None
3956
Dimension along which the sum is required.
57+
nan_val: optional: scalar. default: None
58+
The value that replaces NaN in the array
4059
4160
Returns
4261
-------
4362
out: af.Array or scalar number
4463
The sum of all elements in `a` along dimension `dim`.
4564
If `dim` is `None`, sum of the entire Array is returned.
4665
"""
47-
if dim is not None:
48-
return _parallel_dim(a, dim, backend.get().af_sum)
66+
if (nan_val is not None):
67+
if dim is not None:
68+
return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
69+
else:
70+
return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
4971
else:
50-
return _reduce_all(a, backend.get().af_sum_all)
72+
if dim is not None:
73+
return _parallel_dim(a, dim, backend.get().af_sum)
74+
else:
75+
return _reduce_all(a, backend.get().af_sum_all)
5176

52-
def product(a, dim=None):
77+
def product(a, dim=None, nan_val=None):
5378
"""
5479
Calculate the product of all the elements along a specified dimension.
5580
@@ -59,17 +84,25 @@ def product(a, dim=None):
5984
Multi dimensional arrayfire array.
6085
dim: optional: int. default: None
6186
Dimension along which the product is required.
87+
nan_val: optional: scalar. default: None
88+
The value that replaces NaN in the array
6289
6390
Returns
6491
-------
6592
out: af.Array or scalar number
6693
The product of all elements in `a` along dimension `dim`.
6794
If `dim` is `None`, product of the entire Array is returned.
6895
"""
69-
if dim is not None:
70-
return _parallel_dim(a, dim, backend.get().af_product)
96+
if (nan_val is not None):
97+
if dim is not None:
98+
return _nan_parallel_dim(a, dim, backend.get().af_product_nan, nan_val)
99+
else:
100+
return _nan_reduce_all(a, backend.get().af_product_nan_all, nan_val)
71101
else:
72-
return _reduce_all(a, backend.get().af_product_all)
102+
if dim is not None:
103+
return _parallel_dim(a, dim, backend.get().af_product)
104+
else:
105+
return _reduce_all(a, backend.get().af_product_all)
73106

74107
def min(a, dim=None):
75108
"""

tests/simple/algorithm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ def simple_algorithm(verbose = False):
4747
display_func(af.sort(a, is_ascending=True))
4848
display_func(af.sort(a, is_ascending=False))
4949

50+
b = (a > 0.1) * a
51+
c = (a > 0.4) * a
52+
d = b / c
53+
print_func(af.sum(d));
54+
print_func(af.sum(d, nan_val=0.0));
55+
display_func(af.sum(d, dim=0, nan_val=0.0));
56+
5057
val,idx = af.sort_index(a, is_ascending=True)
5158
display_func(val)
5259
display_func(idx)

0 commit comments

Comments
 (0)