@@ -242,6 +242,81 @@ def test_setitem_masking(shape, data):
242242 )
243243
244244
245+ # ### Fancy indexing ###
246+
247+ @pytest .mark .min_version ("2024.12" )
248+ @pytest .mark .unvectorized
249+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
250+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
251+ def test_getitem_arrays_and_ints_1 (shape , data , idx_max_dims ):
252+ # min_dims=2 : test multidim `x` arrays
253+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
255+
256+
257+ @pytest .mark .min_version ("2024.12" )
258+ @pytest .mark .unvectorized
259+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
260+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
261+ def test_getitem_arrays_and_ints_2 (shape , data , idx_max_dims ):
262+ # min_dims=1 : favor 1D `x` arrays
263+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
265+
266+
267+ def _test_getitem_arrays_and_ints (shape , data , idx_max_dims ):
268+ assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
269+
270+ dtype = xp .int32
271+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
272+ x = xp .asarray (obj , dtype = dtype )
273+
274+ # draw a mix of ints and index arrays
275+ arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
276+ assume (sum (arr_index ) > 0 )
277+
278+ # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279+ # max_dims=None ==> multidim indexing arrays
280+ if sum (arr_index ) > 0 :
281+ index_shapes = data .draw (
282+ hh .mutually_broadcastable_shapes (
283+ sum (arr_index ), min_dims = 1 , max_dims = idx_max_dims , min_side = 1
284+ )
285+ )
286+ index_shapes = list (index_shapes )
287+
288+ # prepare the indexing tuple, a mix of integer indices and index arrays
289+ key = []
290+ for i ,typ in enumerate (arr_index ):
291+ if typ :
292+ # draw an array index
293+ this_idx = data .draw (
294+ xps .arrays (
295+ dtype ,
296+ shape = index_shapes .pop (),
297+ elements = st .integers (0 , shape [i ]- 1 )
298+ )
299+ )
300+ key .append (this_idx )
301+
302+ else :
303+ # draw an integer
304+ key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
305+
306+ print (f"??? { x .shape = } { len (key ) = } { [xp .asarray (k ).shape for k in key ]} " )
307+
308+ key = tuple (key )
309+ out = x [key ]
310+
311+ arrays = [xp .asarray (k ) for k in key ]
312+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
313+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
314+
315+ for idx in sh .ndindex (bcast_shape ):
316+ tpl = tuple (k [idx ] for k in bcast_key )
317+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
318+
319+
245320def make_scalar_casting_param (
246321 method_name : str , dtype : DataType , stype : ScalarType
247322) -> Param :
0 commit comments