@@ -488,7 +488,7 @@ def device_ptr(self):
488488 Note
489489 ----
490490 - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
491- - No other arrays will share the same device pointer.
491+ - No other arrays will share the same device pointer.
492492 - A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
493493 - In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
494494 """
@@ -985,6 +985,12 @@ def __getitem__(self, key):
985985 try :
986986 out = Array ()
987987 n_dims = self .numdims ()
988+
989+ if (isinstance (key , Array ) and key .type () == Dtype .b8 .value ):
990+ n_dims = 1
991+ if (count (key ) == 0 ):
992+ return out
993+
988994 inds = _get_indices (key )
989995
990996 safe_call (backend .get ().af_index_gen (ct .pointer (out .arr ),
@@ -1005,9 +1011,21 @@ def __setitem__(self, key, val):
10051011 try :
10061012 n_dims = self .numdims ()
10071013
1014+ is_boolean_idx = isinstance (key , Array ) and key .type () == Dtype .b8 .value
1015+
1016+ if (is_boolean_idx ):
1017+ n_dims = 1
1018+ num = count (key )
1019+ if (num == 0 ):
1020+ return
1021+
10081022 if (_is_number (val )):
10091023 tdims = _get_assign_dims (key , self .dims ())
1010- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1024+ if (is_boolean_idx ):
1025+ n_dims = 1
1026+ other_arr = constant_array (val , int (num ), dtype = self .type ())
1027+ else :
1028+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
10111029 del_other = True
10121030 else :
10131031 other_arr = val .arr
@@ -1017,8 +1035,8 @@ def __setitem__(self, key, val):
10171035 inds = _get_indices (key )
10181036
10191037 safe_call (backend .get ().af_assign_gen (ct .pointer (out_arr ),
1020- self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1021- other_arr ))
1038+ self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1039+ other_arr ))
10221040 safe_call (backend .get ().af_release_array (self .arr ))
10231041 if del_other :
10241042 safe_call (backend .get ().af_release_array (other_arr ))
@@ -1235,5 +1253,5 @@ def read_array(filename, index=None, key=None):
12351253
12361254 return out
12371255
1238- from .algorithm import sum
1256+ from .algorithm import ( sum , count )
12391257from .arith import cast
0 commit comments