@@ -107,7 +107,50 @@ def test_setitem(shape, data):
107107 ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
108108
109109
110- # TODO: test boolean indexing
110+ # TODO: make mask tests optional
111+
112+
113+ @given (hh .shapes (), st .data ())
114+ def test_getitem_mask (shape , data ):
115+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
116+ mask_shapes = st .one_of (
117+ st .sampled_from ([x .shape , ()]),
118+ st .lists (st .booleans (), min_size = x .ndim , max_size = x .ndim ).map (
119+ lambda l : tuple (s if b else 0 for s , b in zip (x .shape , l ))
120+ ),
121+ hh .shapes (),
122+ )
123+ key = data .draw (xps .arrays (dtype = xp .bool , shape = mask_shapes ), label = "key" )
124+
125+ if key .ndim > x .ndim or not all (
126+ ks in (xs , 0 ) for xs , ks in zip (x .shape , key .shape )
127+ ):
128+ with pytest .raises (IndexError ):
129+ x [key ]
130+ return
131+
132+ out = x [key ]
133+
134+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
135+ if key .ndim == 0 :
136+ out_shape = (1 ,) if key else (0 ,)
137+ out_shape += x .shape
138+ else :
139+ size = int (xp .sum (xp .astype (key , xp .uint8 )))
140+ out_shape = (size ,) + x .shape [key .ndim :]
141+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
142+
143+
144+ @given (hh .shapes (min_side = 1 ), st .data ())
145+ def test_setitem_mask (shape , data ):
146+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
147+ key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
148+ value = data .draw (xps .from_dtype (x .dtype ), label = "value" )
149+
150+ res = xp .asarray (x , copy = True )
151+ res [key ] = value
152+
153+ # TODO
111154
112155
113156def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments