11import math
22from itertools import product
3- from typing import List , Union , get_args
3+ from typing import List , Sequence , Tuple , Union , get_args
44
55import pytest
66from hypothesis import assume , given , note
@@ -28,7 +28,7 @@ def scalar_objects(
2828 )
2929
3030
31- def normalise_key (key : Index , shape : Shape ):
31+ def normalise_key (key : Index , shape : Shape ) -> Tuple [ Union [ int , slice ], ...] :
3232 """
3333 Normalise an indexing key.
3434
@@ -46,40 +46,52 @@ def normalise_key(key: Index, shape: Shape):
4646 return _key
4747
4848
49- @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
50- def test_getitem (shape , dtype , data ):
51- zero_sided = any (side == 0 for side in shape )
52- if zero_sided :
53- x = xp .zeros (shape , dtype = dtype )
54- else :
55- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
56- x = xp .asarray (obj , dtype = dtype )
57- note (f"{ x = } " )
58- key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
59-
60- out = x [key ]
49+ def get_indexed_axes_and_out_shape (
50+ key : Tuple [Union [int , slice , None ], ...], shape : Shape
51+ ) -> Tuple [Tuple [Sequence [int ], ...], Shape ]:
52+ """
53+ From the (normalised) key and input shape, calculates:
6154
62- ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
63- _key = normalise_key (key , shape )
55+ * indexed_axes: For each dimension, the axes which the key indexes.
56+ * out_shape: The resulting shape of indexing an array (of the input shape)
57+ with the key.
58+ """
6459 axes_indices = []
6560 out_shape = []
6661 a = 0
67- for i in _key :
62+ for i in key :
6863 if i is None :
6964 out_shape .append (1 )
7065 else :
7166 side = shape [a ]
7267 if isinstance (i , int ):
7368 if i < 0 :
7469 i += side
75- axes_indices .append ([ i ] )
70+ axes_indices .append (( i ,) )
7671 else :
77- assert isinstance (i , slice ) # sanity check
7872 indices = range (side )[i ]
7973 axes_indices .append (indices )
8074 out_shape .append (len (indices ))
8175 a += 1
82- out_shape = tuple (out_shape )
76+ return tuple (axes_indices ), tuple (out_shape )
77+
78+
79+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
80+ def test_getitem (shape , dtype , data ):
81+ zero_sided = any (side == 0 for side in shape )
82+ if zero_sided :
83+ x = xp .zeros (shape , dtype = dtype )
84+ else :
85+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86+ x = xp .asarray (obj , dtype = dtype )
87+ note (f"{ x = } " )
88+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
89+
90+ out = x [key ]
91+
92+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
93+ _key = normalise_key (key , shape )
94+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
8395 ph .assert_shape ("__getitem__" , out .shape , out_shape )
8496 out_zero_sided = any (side == 0 for side in out_shape )
8597 if not zero_sided and not out_zero_sided :
@@ -109,13 +121,7 @@ def test_setitem(shape, dtypes, data):
109121 note (f"{ x = } " )
110122 key = data .draw (xps .indices (shape = shape ), label = "key" )
111123 _key = normalise_key (key , shape )
112- out_shape = []
113-
114- for i , side in zip (_key , shape ):
115- if isinstance (i , slice ):
116- indices = range (side )[i ]
117- out_shape .append (len (indices ))
118- out_shape = tuple (out_shape )
124+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
119125 value_strat = xps .arrays (dtype = dtypes .result_dtype , shape = out_shape )
120126 if out_shape == ():
121127 # We can pass scalars if we're only indexing one element
@@ -127,7 +133,6 @@ def test_setitem(shape, dtypes, data):
127133
128134 ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
129135 ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
130-
131136 f_res = sh .fmt_idx ("x" , key )
132137 if isinstance (value , get_args (Scalar )):
133138 msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
@@ -137,16 +142,6 @@ def test_setitem(shape, dtypes, data):
137142 assert res [key ] == value , msg
138143 else :
139144 ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
140-
141- axes_indices = []
142- for i , side in zip (_key , shape ):
143- if isinstance (i , int ):
144- if i < 0 :
145- i += side
146- axes_indices .append ([i ])
147- else :
148- indices = range (side )[i ]
149- axes_indices .append (indices )
150145 unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
151146 for idx in unaffected_indices :
152147 ph .assert_0d_equals (
0 commit comments