Skip to content

Commit da82001

Browse files
committed
Correctly calculate dimensions for assignments with boolean keys
1 parent c3a273e commit da82001

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

arrayfire/array.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .base import *
2020
from .index import *
2121
from .index import _Index4
22+
from .algorithm import sum
2223

2324
def _create_array(buf, numdims, idims, dtype):
2425
out_arr = ct.c_void_p(0)
@@ -182,8 +183,12 @@ def _get_assign_dims(key, idims):
182183
elif isinstance(key, ParallelRange):
183184
dims[0] = _slice_to_length(key.S, idims[0])
184185
return dims
185-
elif isinstance(key, BaseArray):
186-
dims[0] = key.elements()
186+
elif isinstance(key, BaseArray):
187+
# If the array is boolean take only the number of nonzeros
188+
if(key.dtype() is Dtype.b8):
189+
dims[0] = int(sum(key))
190+
else:
191+
dims[0] = key.elements()
187192
return dims
188193
elif isinstance(key, tuple):
189194
n_inds = len(key)
@@ -192,7 +197,11 @@ def _get_assign_dims(key, idims):
192197
if (_is_number(key[n])):
193198
dims[n] = 1
194199
elif (isinstance(key[n], BaseArray)):
195-
dims[n] = key[n].elements()
200+
# If the array is boolean take only the number of nonzeros
201+
if(key[n].dtype() is Dtype.b8):
202+
dims[n] = int(sum(key[n]))
203+
else:
204+
dims[n] = key[n].elements()
196205
elif (isinstance(key[n], slice)):
197206
dims[n] = _slice_to_length(key[n], idims[n])
198207
elif (isinstance(key[n], ParallelRange)):

0 commit comments

Comments
 (0)