1010from . import dtype_helpers as dh
1111from . import hypothesis_helpers as hh
1212from . import pytest_helpers as ph
13+ from . import shape_helpers as sh
1314from . import xps
1415from .typing import DataType , Param , Scalar , ScalarType , Shape
1516
@@ -87,6 +88,7 @@ def test_setitem(shape, data):
8788 )
8889 x = xp .asarray (obj , dtype = dtype )
8990 note (f"{ x = } " )
91+ # TODO: test setting non-0d arrays
9092 key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
9193 value = data .draw (
9294 xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
@@ -104,10 +106,100 @@ def test_setitem(shape, data):
104106 else :
105107 assert res [key ] == value , msg
106108 else :
107- ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
109+ ph .assert_0d_equals (
110+ "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
111+ )
112+ _key = key if isinstance (key , tuple ) else (key ,)
113+ assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
114+ _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
115+ unaffected_indices = list (sh .ndindex (res .shape ))
116+ unaffected_indices .remove (_key )
117+ for idx in unaffected_indices :
118+ ph .assert_0d_equals (
119+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
120+ )
121+
122+
123+ # TODO: make mask tests optional
124+
125+
126+ @given (hh .shapes (), st .data ())
127+ def test_getitem_masking (shape , data ):
128+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
129+ mask_shapes = st .one_of (
130+ st .sampled_from ([x .shape , ()]),
131+ st .lists (st .booleans (), min_size = x .ndim , max_size = x .ndim ).map (
132+ lambda l : tuple (s if b else 0 for s , b in zip (x .shape , l ))
133+ ),
134+ hh .shapes (),
135+ )
136+ key = data .draw (xps .arrays (dtype = xp .bool , shape = mask_shapes ), label = "key" )
108137
138+ if key .ndim > x .ndim or not all (
139+ ks in (xs , 0 ) for xs , ks in zip (x .shape , key .shape )
140+ ):
141+ with pytest .raises (IndexError ):
142+ x [key ]
143+ return
144+
145+ out = x [key ]
109146
110- # TODO: test boolean indexing
147+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
148+ if key .ndim == 0 :
149+ out_shape = (1 ,) if key else (0 ,)
150+ out_shape += x .shape
151+ else :
152+ size = int (xp .sum (xp .astype (key , xp .uint8 )))
153+ out_shape = (size ,) + x .shape [key .ndim :]
154+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
155+ if not any (s == 0 for s in key .shape ):
156+ assume (key .ndim == x .ndim ) # TODO: test key.ndim < x.ndim scenarios
157+ out_indices = sh .ndindex (out .shape )
158+ for x_idx in sh .ndindex (x .shape ):
159+ if key [x_idx ]:
160+ out_idx = next (out_indices )
161+ ph .assert_0d_equals (
162+ "__getitem__" ,
163+ f"x[{ x_idx } ]" ,
164+ x [x_idx ],
165+ f"out[{ out_idx } ]" ,
166+ out [out_idx ],
167+ )
168+
169+
170+ @given (hh .shapes (), st .data ())
171+ def test_setitem_masking (shape , data ):
172+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
173+ key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
174+ value = data .draw (
175+ xps .from_dtype (x .dtype ) | xps .arrays (dtype = x .dtype , shape = ()), label = "value"
176+ )
177+
178+ res = xp .asarray (x , copy = True )
179+ res [key ] = value
180+
181+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
182+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.dtype" )
183+ scalar_type = dh .get_scalar_type (x .dtype )
184+ for idx in sh .ndindex (x .shape ):
185+ if key [idx ]:
186+ if isinstance (value , Scalar ):
187+ ph .assert_scalar_equals (
188+ "__setitem__" ,
189+ scalar_type ,
190+ idx ,
191+ scalar_type (res [idx ]),
192+ value ,
193+ repr_name = "modified x" ,
194+ )
195+ else :
196+ ph .assert_0d_equals (
197+ "__setitem__" , "value" , value , f"modified x[{ idx } ]" , res [idx ]
198+ )
199+ else :
200+ ph .assert_0d_equals (
201+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
202+ )
111203
112204
113205def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments