@@ -22,12 +22,29 @@ def _parallel_dim(a, dim, c_func):
2222def _reduce_all (a , c_func ):
2323 real = ct .c_double (0 )
2424 imag = ct .c_double (0 )
25+
2526 safe_call (c_func (ct .pointer (real ), ct .pointer (imag ), a .arr ))
27+
2628 real = real .value
2729 imag = imag .value
2830 return real if imag == 0 else real + imag * 1j
2931
30- def sum (a , dim = None ):
32+ def _nan_parallel_dim (a , dim , c_func , nan_val ):
33+ out = Array ()
34+ safe_call (c_func (ct .pointer (out .arr ), a .arr , ct .c_int (dim ), ct .c_double (nan_val )))
35+ return out
36+
37+ def _nan_reduce_all (a , c_func , nan_val ):
38+ real = ct .c_double (0 )
39+ imag = ct .c_double (0 )
40+
41+ safe_call (c_func (ct .pointer (real ), ct .pointer (imag ), a .arr , ct .c_double (nan_val )))
42+
43+ real = real .value
44+ imag = imag .value
45+ return real if imag == 0 else real + imag * 1j
46+
47+ def sum (a , dim = None , nan_val = None ):
3148 """
3249 Calculate the sum of all the elements along a specified dimension.
3350
@@ -37,19 +54,27 @@ def sum(a, dim=None):
3754 Multi dimensional arrayfire array.
3855 dim: optional: int. default: None
3956 Dimension along which the sum is required.
57+ nan_val: optional: scalar. default: None
58+ The value that replaces NaN in the array
4059
4160 Returns
4261 -------
4362 out: af.Array or scalar number
4463 The sum of all elements in `a` along dimension `dim`.
4564 If `dim` is `None`, sum of the entire Array is returned.
4665 """
47- if dim is not None :
48- return _parallel_dim (a , dim , backend .get ().af_sum )
66+ if (nan_val is not None ):
67+ if dim is not None :
68+ return _nan_parallel_dim (a , dim , backend .get ().af_sum_nan , nan_val )
69+ else :
70+ return _nan_reduce_all (a , backend .get ().af_sum_nan_all , nan_val )
4971 else :
50- return _reduce_all (a , backend .get ().af_sum_all )
72+ if dim is not None :
73+ return _parallel_dim (a , dim , backend .get ().af_sum )
74+ else :
75+ return _reduce_all (a , backend .get ().af_sum_all )
5176
52- def product (a , dim = None ):
77+ def product (a , dim = None , nan_val = None ):
5378 """
5479 Calculate the product of all the elements along a specified dimension.
5580
@@ -59,17 +84,25 @@ def product(a, dim=None):
5984 Multi dimensional arrayfire array.
6085 dim: optional: int. default: None
6186 Dimension along which the product is required.
87+ nan_val: optional: scalar. default: None
88+ The value that replaces NaN in the array
6289
6390 Returns
6491 -------
6592 out: af.Array or scalar number
6693 The product of all elements in `a` along dimension `dim`.
6794 If `dim` is `None`, product of the entire Array is returned.
6895 """
69- if dim is not None :
70- return _parallel_dim (a , dim , backend .get ().af_product )
96+ if (nan_val is not None ):
97+ if dim is not None :
98+ return _nan_parallel_dim (a , dim , backend .get ().af_product_nan , nan_val )
99+ else :
100+ return _nan_reduce_all (a , backend .get ().af_product_nan_all , nan_val )
71101 else :
72- return _reduce_all (a , backend .get ().af_product_all )
102+ if dim is not None :
103+ return _parallel_dim (a , dim , backend .get ().af_product )
104+ else :
105+ return _reduce_all (a , backend .get ().af_product_all )
73106
74107def min (a , dim = None ):
75108 """
0 commit comments