11import math
22from itertools import product
3- from typing import List , get_args
3+ from typing import List , Sequence , Tuple , Union , get_args
44
55import pytest
66from hypothesis import assume , given , note
1212from . import pytest_helpers as ph
1313from . import shape_helpers as sh
1414from . import xps
15- from .typing import DataType , Param , Scalar , ScalarType , Shape
15+ from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16+ from .typing import DataType , Index , Param , Scalar , ScalarType , Shape
1617
1718pytestmark = pytest .mark .ci
1819
1920
20- def scalar_objects (dtype : DataType , shape : Shape ) -> st .SearchStrategy [List [Scalar ]]:
21+ def scalar_objects (
22+ dtype : DataType , shape : Shape
23+ ) -> st .SearchStrategy [Union [Scalar , List [Scalar ]]]:
2124 """Generates scalars or nested sequences which are valid for xp.asarray()"""
2225 size = math .prod (shape )
2326 return st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
2427 lambda l : sh .reshape (l , shape )
2528 )
2629
2730
28- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
29- def test_getitem (shape , data ):
30- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
31- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
32- x = xp .asarray (obj , dtype = dtype )
33- note (f"{ x = } " )
34- key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
35-
36- out = x [key ]
31+ def normalise_key (key : Index , shape : Shape ) -> Tuple [Union [int , slice ], ...]:
32+ """
33+ Normalise an indexing key.
3734
38- ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
35+ * If a non-tuple index, wrap as a tuple.
36+ * Represent ellipsis as equivalent slices.
37+ """
3938 _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
4039 if Ellipsis in _key :
4140 nonexpanding_key = tuple (i for i in _key if i is not None )
@@ -44,71 +43,109 @@ def test_getitem(shape, data):
4443 slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
4544 start_pos = _key .index (Ellipsis )
4645 _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
46+ return _key
47+
48+
49+ def get_indexed_axes_and_out_shape (
50+ key : Tuple [Union [int , slice , None ], ...], shape : Shape
51+ ) -> Tuple [Tuple [Sequence [int ], ...], Shape ]:
52+ """
53+ From the (normalised) key and input shape, calculates:
54+
55+ * indexed_axes: For each dimension, the axes which the key indexes.
56+ * out_shape: The resulting shape of indexing an array (of the input shape)
57+ with the key.
58+ """
4759 axes_indices = []
4860 out_shape = []
4961 a = 0
50- for i in _key :
62+ for i in key :
5163 if i is None :
5264 out_shape .append (1 )
5365 else :
66+ side = shape [a ]
5467 if isinstance (i , int ):
55- axes_indices .append ([i ])
68+ if i < 0 :
69+ i += side
70+ axes_indices .append ((i ,))
5671 else :
57- assert isinstance (i , slice ) # sanity check
58- side = shape [a ]
5972 indices = range (side )[i ]
6073 axes_indices .append (indices )
6174 out_shape .append (len (indices ))
6275 a += 1
63- out_shape = tuple (out_shape )
76+ return tuple (axes_indices ), tuple (out_shape )
77+
78+
79+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
80+ def test_getitem (shape , dtype , data ):
81+ zero_sided = any (side == 0 for side in shape )
82+ if zero_sided :
83+ x = xp .zeros (shape , dtype = dtype )
84+ else :
85+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86+ x = xp .asarray (obj , dtype = dtype )
87+ note (f"{ x = } " )
88+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
89+
90+ out = x [key ]
91+
92+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
93+ _key = normalise_key (key , shape )
94+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
6495 ph .assert_shape ("__getitem__" , out .shape , out_shape )
65- assume (all (len (indices ) > 0 for indices in axes_indices ))
66- out_obj = []
67- for idx in product (* axes_indices ):
68- val = obj
69- for i in idx :
70- val = val [i ]
71- out_obj .append (val )
72- out_obj = sh .reshape (out_obj , out_shape )
73- expected = xp .asarray (out_obj , dtype = dtype )
74- ph .assert_array ("__getitem__" , out , expected )
75-
76-
77- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
78- def test_setitem (shape , data ):
79- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
80- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
81- x = xp .asarray (obj , dtype = dtype )
96+ out_zero_sided = any (side == 0 for side in out_shape )
97+ if not zero_sided and not out_zero_sided :
98+ out_obj = []
99+ for idx in product (* axes_indices ):
100+ val = obj
101+ for i in idx :
102+ val = val [i ]
103+ out_obj .append (val )
104+ out_obj = sh .reshape (out_obj , out_shape )
105+ expected = xp .asarray (out_obj , dtype = dtype )
106+ ph .assert_array_elements ("__getitem__" , out , expected )
107+
108+
109+ @given (
110+ shape = hh .shapes (),
111+ dtypes = oneway_promotable_dtypes (dh .all_dtypes ),
112+ data = st .data (),
113+ )
114+ def test_setitem (shape , dtypes , data ):
115+ zero_sided = any (side == 0 for side in shape )
116+ if zero_sided :
117+ x = xp .zeros (shape , dtype = dtypes .result_dtype )
118+ else :
119+ obj = data .draw (scalar_objects (dtypes .result_dtype , shape ), label = "obj" )
120+ x = xp .asarray (obj , dtype = dtypes .result_dtype )
82121 note (f"{ x = } " )
83- # TODO: test setting non-0d arrays
84- key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
85- value = data .draw (
86- xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
87- )
122+ key = data .draw (xps .indices (shape = shape ), label = "key" )
123+ _key = normalise_key (key , shape )
124+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
125+ value_strat = xps .arrays (dtype = dtypes .result_dtype , shape = out_shape )
126+ if out_shape == ():
127+ # We can pass scalars if we're only indexing one element
128+ value_strat |= xps .from_dtype (dtypes .result_dtype )
129+ value = data .draw (value_strat , label = "value" )
88130
89131 res = xp .asarray (x , copy = True )
90132 res [key ] = value
91133
92134 ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
93135 ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
136+ f_res = sh .fmt_idx ("x" , key )
94137 if isinstance (value , get_args (Scalar )):
95- msg = f"x[ { key } ] ={ res [key ]!r} , but should be { value = } [__setitem__()]"
138+ msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
96139 if math .isnan (value ):
97140 assert xp .isnan (res [key ]), msg
98141 else :
99142 assert res [key ] == value , msg
100143 else :
101- ph .assert_0d_equals (
102- "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
103- )
104- _key = key if isinstance (key , tuple ) else (key ,)
105- assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
106- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
107- unaffected_indices = list (sh .ndindex (res .shape ))
108- unaffected_indices .remove (_key )
144+ ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
145+ unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
109146 for idx in unaffected_indices :
110147 ph .assert_0d_equals (
111- "__setitem__" , f"old x[ { idx } ] " , x [idx ], f"modified x[ { idx } ] " , res [idx ]
148+ "__setitem__" , f"old { f_res } " , x [idx ], f"modified { f_res } " , res [idx ]
112149 )
113150
114151
0 commit comments