1313from . import shape_helpers as sh
1414from . import xps
1515from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16- from .typing import DataType , Param , Scalar , ScalarType , Shape
16+ from .typing import DataType , Index , Param , Scalar , ScalarType , Shape
1717
1818pytestmark = pytest .mark .ci
1919
@@ -28,6 +28,24 @@ def scalar_objects(
2828 )
2929
3030
31+ def normalise_key (key : Index , shape : Shape ):
32+ """
33+ Normalise an indexing key.
34+
35+ * If a non-tuple index, wrap as a tuple.
36+ * Represent ellipsis as equivalent slices.
37+ """
38+ _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
39+ if Ellipsis in _key :
40+ nonexpanding_key = tuple (i for i in _key if i is not None )
41+ start_a = nonexpanding_key .index (Ellipsis )
42+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
43+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
44+ start_pos = _key .index (Ellipsis )
45+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
46+ return _key
47+
48+
3149@given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
3250def test_getitem (shape , dtype , data ):
3351 zero_sided = any (side == 0 for side in shape )
@@ -42,14 +60,7 @@ def test_getitem(shape, dtype, data):
4260 out = x [key ]
4361
4462 ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
45- _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
46- if Ellipsis in _key :
47- nonexpanding_key = tuple (i for i in _key if i is not None )
48- start_a = nonexpanding_key .index (Ellipsis )
49- stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
50- slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
51- start_pos = _key .index (Ellipsis )
52- _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
63+ _key = normalise_key (key , shape )
5364 axes_indices = []
5465 out_shape = []
5566 a = 0
@@ -97,14 +108,7 @@ def test_setitem(shape, dtypes, data):
97108 x = xp .asarray (obj , dtype = dtypes .result_dtype )
98109 note (f"{ x = } " )
99110 key = data .draw (xps .indices (shape = shape ), label = "key" )
100- _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
101- if Ellipsis in _key :
102- nonexpanding_key = tuple (i for i in _key if i is not None )
103- start_a = nonexpanding_key .index (Ellipsis )
104- stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
105- slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
106- start_pos = _key .index (Ellipsis )
107- _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
111+ _key = normalise_key (key , shape )
108112 out_shape = []
109113
110114 for i , side in zip (_key , shape ):
0 commit comments