@@ -88,6 +88,7 @@ def test_setitem(shape, data):
8888 )
8989 x = xp .asarray (obj , dtype = dtype )
9090 note (f"{ x = } " )
91+ # TODO: test setting non-0d arrays
9192 key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
9293 value = data .draw (
9394 xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
@@ -105,7 +106,18 @@ def test_setitem(shape, data):
105106 else :
106107 assert res [key ] == value , msg
107108 else :
108- ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
109+ ph .assert_0d_equals (
110+ "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
111+ )
112+ _key = key if isinstance (key , tuple ) else (key ,)
113+ assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
114+ _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
115+ unaffected_indices = list (sh .ndindex (res .shape ))
116+ unaffected_indices .remove (_key )
117+ for idx in unaffected_indices :
118+ ph .assert_0d_equals (
119+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
120+ )
109121
110122
111123# TODO: make mask tests optional
@@ -140,7 +152,6 @@ def test_getitem_mask(shape, data):
140152 size = int (xp .sum (xp .astype (key , xp .uint8 )))
141153 out_shape = (size ,) + x .shape [key .ndim :]
142154 ph .assert_shape ("__getitem__" , out .shape , out_shape )
143-
144155 if not any (s == 0 for s in key .shape ):
145156 assume (key .ndim == x .ndim ) # TODO: test key.ndim < x.ndim scenarios
146157 out_indices = sh .ndindex (out .shape )
@@ -169,7 +180,6 @@ def test_setitem_mask(shape, data):
169180
170181 ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
171182 ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.dtype" )
172-
173183 scalar_type = dh .get_scalar_type (x .dtype )
174184 for idx in sh .ndindex (x .shape ):
175185 if key [idx ]:
0 commit comments