@@ -242,25 +242,29 @@ def test_setitem_masking(shape, data):
242242 )
243243
244244
245+ # ### Fancy indexing ###
246+
245247@pytest .mark .min_version ("2024.12" )
246248@pytest .mark .unvectorized
249+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
247250@given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
248- def test_getitem_arrays_and_ints_1 (shape , data ):
251+ def test_getitem_arrays_and_ints_1 (shape , data , idx_max_dims ):
249252 # min_dims=2 : test multidim `x` arrays
250- # index arrays are all 1D
251- _test_getitem_arrays_and_ints_1D (shape , data )
253+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
254+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
252255
253256
254257@pytest .mark .min_version ("2024.12" )
255258@pytest .mark .unvectorized
259+ @pytest .mark .parametrize ("idx_max_dims" , [1 , None ])
256260@given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
257- def test_getitem_arrays_and_ints_2 (shape , data ):
261+ def test_getitem_arrays_and_ints_2 (shape , data , idx_max_dims ):
258262 # min_dims=1 : favor 1D `x` arrays
259- # index arrays are all 1D
260- _test_getitem_arrays_and_ints_1D (shape , data )
263+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
264+ _test_getitem_arrays_and_ints (shape , data , idx_max_dims )
261265
262266
263- def _test_getitem_arrays_and_ints_1D (shape , data ):
267+ def _test_getitem_arrays_and_ints (shape , data , idx_max_dims ):
264268 assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
265269
266270 dtype = xp .int32
@@ -271,11 +275,12 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
271275 arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
272276 assume (sum (arr_index ) > 0 )
273277
274- # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
278+ # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
279+ # max_dims=None ==> multidim indexing arrays
275280 if sum (arr_index ) > 0 :
276281 index_shapes = data .draw (
277282 hh .mutually_broadcastable_shapes (
278- sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
283+ sum (arr_index ), min_dims = 1 , max_dims = idx_max_dims , min_side = 1
279284 )
280285 )
281286 index_shapes = list (index_shapes )
@@ -298,7 +303,7 @@ def _test_getitem_arrays_and_ints_1D(shape, data):
298303 # draw an integer
299304 key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
300305
301- # print(f"??? {x.shape = } {key = }")
306+ print (f"??? { x .shape = } { len ( key ) = } { [ xp . asarray ( k ). shape for k in key ] } " )
302307
303308 key = tuple (key )
304309 out = x [key ]
0 commit comments