@@ -242,6 +242,76 @@ def test_setitem_masking(shape, data):
242242 )
243243
244244
245+ @pytest .mark .min_version ("2024.12" )
246+ @pytest .mark .unvectorized
247+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
248+ def test_getitem_arrays_and_ints_1 (shape , data ):
249+ # min_dims=2 : test multidim `x` arrays
250+ # index arrays are all 1D
251+ _test_getitem_arrays_and_ints_1D (shape , data )
252+
253+
254+ @pytest .mark .min_version ("2024.12" )
255+ @pytest .mark .unvectorized
256+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
257+ def test_getitem_arrays_and_ints_2 (shape , data ):
258+ # min_dims=1 : favor 1D `x` arrays
259+ # index arrays are all 1D
260+ _test_getitem_arrays_and_ints_1D (shape , data )
261+
262+
263+ def _test_getitem_arrays_and_ints_1D (shape , data ):
264+ assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
265+
266+ dtype = xp .int32
267+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
268+ x = xp .asarray (obj , dtype = dtype )
269+
270+ # draw a mix of ints and index arrays
271+ arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
272+ assume (sum (arr_index ) > 0 )
273+
274+ # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
275+ if sum (arr_index ) > 0 :
276+ index_shapes = data .draw (
277+ hh .mutually_broadcastable_shapes (
278+ sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
279+ )
280+ )
281+ index_shapes = list (index_shapes )
282+
283+ # prepare the indexing tuple, a mix of integer indices and index arrays
284+ key = []
285+ for i ,typ in enumerate (arr_index ):
286+ if typ :
287+ # draw an array index
288+ this_idx = data .draw (
289+ xps .arrays (
290+ dtype ,
291+ shape = index_shapes .pop (),
292+ elements = st .integers (0 , shape [i ]- 1 )
293+ )
294+ )
295+ key .append (this_idx )
296+
297+ else :
298+ # draw an integer
299+ key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
300+
301+ # print(f"??? {x.shape = } {key = }")
302+
303+ key = tuple (key )
304+ out = x [key ]
305+
306+ arrays = [xp .asarray (k ) for k in key ]
307+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
308+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
309+
310+ for idx in sh .ndindex (bcast_shape ):
311+ tpl = tuple (k [idx ] for k in bcast_key )
312+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
313+
314+
245315def make_scalar_casting_param (
246316 method_name : str , dtype : DataType , stype : ScalarType
247317) -> Param :
0 commit comments