@@ -55,11 +55,13 @@ def test_getitem(shape, dtype, data):
5555 if i is None :
5656 out_shape .append (1 )
5757 else :
58+ side = shape [a ]
5859 if isinstance (i , int ):
60+ if i < 0 :
61+ i += side
5962 axes_indices .append ([i ])
6063 else :
6164 assert isinstance (i , slice ) # sanity check
62- side = shape [a ]
6365 indices = range (side )[i ]
6466 axes_indices .append (indices )
6567 out_shape .append (len (indices ))
@@ -102,9 +104,9 @@ def test_setitem(shape, dtypes, data):
102104 start_pos = _key .index (Ellipsis )
103105 _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
104106 out_shape = []
105- for a , i in enumerate (_key ):
107+
108+ for i , side in zip (_key , shape ):
106109 if isinstance (i , slice ):
107- side = shape [a ]
108110 indices = range (side )[i ]
109111 out_shape .append (len (indices ))
110112 out_shape = tuple (out_shape )
@@ -119,7 +121,8 @@ def test_setitem(shape, dtypes, data):
119121
120122 ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
121123 ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
122- f_res = f"res[{ sh .fmt_idx ('x' , key )} ]"
124+
125+ f_res = sh .fmt_idx ("x" , key )
123126 if isinstance (value , get_args (Scalar )):
124127 msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
125128 if math .isnan (value ):
@@ -128,14 +131,21 @@ def test_setitem(shape, dtypes, data):
128131 assert res [key ] == value , msg
129132 else :
130133 ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
131- if all (isinstance (i , int ) for i in _key ): # TODO: normalise slices and ellipsis
132- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
133- unaffected_indices = list (sh .ndindex (res .shape ))
134- unaffected_indices .remove (_key )
135- for idx in unaffected_indices :
136- ph .assert_0d_equals (
137- "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
138- )
134+
135+ axes_indices = []
136+ for i , side in zip (_key , shape ):
137+ if isinstance (i , int ):
138+ if i < 0 :
139+ i += side
140+ axes_indices .append ([i ])
141+ else :
142+ indices = range (side )[i ]
143+ axes_indices .append (indices )
144+ unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
145+ for idx in unaffected_indices :
146+ ph .assert_0d_equals (
147+ "__setitem__" , f"old { f_res } " , x [idx ], f"modified { f_res } " , res [idx ]
148+ )
139149
140150
141151@pytest .mark .data_dependent_shapes
0 commit comments