@@ -156,16 +156,40 @@ def test_getitem_mask(shape, data):
156156 )
157157
158158
159- @given (hh .shapes (min_side = 1 ), st .data ())
159+ @given (hh .shapes (), st .data ())
160160def test_setitem_mask (shape , data ):
161161 x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
162162 key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
163- value = data .draw (xps .from_dtype (x .dtype ), label = "value" ) # TODO: more values
163+ value = data .draw (
164+ xps .from_dtype (x .dtype ) | xps .arrays (dtype = x .dtype , shape = ()), label = "value"
165+ )
164166
165167 res = xp .asarray (x , copy = True )
166168 res [key ] = value
167169
168- # TODO
170+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
171+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.dtype" )
172+
173+ scalar_type = dh .get_scalar_type (x .dtype )
174+ for idx in sh .ndindex (x .shape ):
175+ if key [idx ]:
176+ if isinstance (value , Scalar ):
177+ ph .assert_scalar_equals (
178+ "__setitem__" ,
179+ scalar_type ,
180+ idx ,
181+ scalar_type (res [idx ]),
182+ value ,
183+ repr_name = "modified x" ,
184+ )
185+ else :
186+ ph .assert_0d_equals (
187+ "__setitem__" , "value" , value , f"modified x[{ idx } ]" , res [idx ]
188+ )
189+ else :
190+ ph .assert_0d_equals (
191+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
192+ )
169193
170194
171195def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments