1212from . import pytest_helpers as ph
1313from . import shape_helpers as sh
1414from . import xps
15+ from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1516from .typing import DataType , Param , Scalar , ScalarType , Shape
1617
1718pytestmark = pytest .mark .ci
@@ -78,14 +79,18 @@ def test_getitem(shape, dtype, data):
7879 ph .assert_array_elements ("__getitem__" , out , expected )
7980
8081
81- @given (shape = hh .shapes (min_side = 1 ), dtype = xps .scalar_dtypes (), data = st .data ())
82- def test_setitem (shape , dtype , data ):
82+ @given (
83+ shape = hh .shapes (),
84+ dtypes = oneway_promotable_dtypes (dh .all_dtypes ),
85+ data = st .data (),
86+ )
87+ def test_setitem (shape , dtypes , data ):
8388 zero_sided = any (side == 0 for side in shape )
8489 if zero_sided :
85- x = xp .zeros (shape , dtype = dtype )
90+ x = xp .zeros (shape , dtype = dtypes . result_dtype )
8691 else :
87- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
88- x = xp .asarray (obj , dtype = dtype )
92+ obj = data .draw (scalar_objects (dtypes . result_dtype , shape ), label = "obj" )
93+ x = xp .asarray (obj , dtype = dtypes . result_dtype )
8994 note (f"{ x = } " )
9095 key = data .draw (xps .indices (shape = shape ), label = "key" )
9196 _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
@@ -103,10 +108,10 @@ def test_setitem(shape, dtype, data):
103108 indices = range (side )[i ]
104109 out_shape .append (len (indices ))
105110 out_shape = tuple (out_shape )
106- value_strat = xps .arrays (dtype = dtype , shape = out_shape )
111+ value_strat = xps .arrays (dtype = dtypes . result_dtype , shape = out_shape )
107112 if out_shape == ():
108113 # We can pass scalars if we're only indexing one element
109- value_strat |= xps .from_dtype (dtype )
114+ value_strat |= xps .from_dtype (dtypes . result_dtype )
110115 value = data .draw (value_strat , label = "value" )
111116
112117 res = xp .asarray (x , copy = True )
0 commit comments