@@ -25,12 +25,11 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal
2525 )
2626
2727
28- @given (hh .shapes (), st .data ())
29- def test_getitem (shape , data ):
30- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
28+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
29+ def test_getitem (shape , dtype , data ):
3130 zero_sided = any (side == 0 for side in shape )
3231 if zero_sided :
33- x = xp .ones (shape , dtype = dtype )
32+ x = xp .zeros (shape , dtype = dtype )
3433 else :
3534 obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
3635 x = xp .asarray (obj , dtype = dtype )
@@ -76,45 +75,62 @@ def test_getitem(shape, data):
7675 out_obj .append (val )
7776 out_obj = sh .reshape (out_obj , out_shape )
7877 expected = xp .asarray (out_obj , dtype = dtype )
79- ph .assert_array ("__getitem__" , out , expected )
78+ ph .assert_array_elements ("__getitem__" , out , expected )
8079
8180
82- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
83- def test_setitem (shape , data ):
84- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
85- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86- x = xp .asarray (obj , dtype = dtype )
81+ @given (shape = hh .shapes (min_side = 1 ), dtype = xps .scalar_dtypes (), data = st .data ())
82+ def test_setitem (shape , dtype , data ):
83+ zero_sided = any (side == 0 for side in shape )
84+ if zero_sided :
85+ x = xp .zeros (shape , dtype = dtype )
86+ else :
87+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
88+ x = xp .asarray (obj , dtype = dtype )
8789 note (f"{ x = } " )
88- # TODO: test setting non-0d arrays
89- key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
90- value = data .draw (
91- xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
92- )
90+ key = data .draw (xps .indices (shape = shape ), label = "key" )
91+ _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
92+ if Ellipsis in _key :
93+ nonexpanding_key = tuple (i for i in _key if i is not None )
94+ start_a = nonexpanding_key .index (Ellipsis )
95+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
96+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
97+ start_pos = _key .index (Ellipsis )
98+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
99+ out_shape = []
100+ for a , i in enumerate (_key ):
101+ if isinstance (i , slice ):
102+ side = shape [a ]
103+ indices = range (side )[i ]
104+ out_shape .append (len (indices ))
105+ out_shape = tuple (out_shape )
106+ value_strat = xps .arrays (dtype = dtype , shape = out_shape )
107+ if out_shape == ():
108+ # We can pass scalars if we're only indexing one element
109+ value_strat |= xps .from_dtype (dtype )
110+ value = data .draw (value_strat , label = "value" )
93111
94112 res = xp .asarray (x , copy = True )
95113 res [key ] = value
96114
97115 ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
98116 ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
117+ f_res = f"res[{ sh .fmt_idx ('x' , key )} ]"
99118 if isinstance (value , get_args (Scalar )):
100- msg = f"x[ { key } ] ={ res [key ]!r} , but should be { value = } [__setitem__()]"
119+ msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
101120 if math .isnan (value ):
102121 assert xp .isnan (res [key ]), msg
103122 else :
104123 assert res [key ] == value , msg
105124 else :
106- ph .assert_0d_equals (
107- "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
108- )
109- _key = key if isinstance (key , tuple ) else (key ,)
110- assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
111- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
112- unaffected_indices = list (sh .ndindex (res .shape ))
113- unaffected_indices .remove (_key )
114- for idx in unaffected_indices :
115- ph .assert_0d_equals (
116- "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
117- )
125+ ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
126+ if all (isinstance (i , int ) for i in _key ): # TODO: normalise slices and ellipsis
127+ _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
128+ unaffected_indices = list (sh .ndindex (res .shape ))
129+ unaffected_indices .remove (_key )
130+ for idx in unaffected_indices :
131+ ph .assert_0d_equals (
132+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
133+ )
118134
119135
120136@pytest .mark .data_dependent_shapes
0 commit comments