@@ -25,11 +25,15 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal
2525 )
2626
2727
28- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
28+ @given (hh .shapes (), st .data ())
2929def test_getitem (shape , data ):
3030 dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
31- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
32- x = xp .asarray (obj , dtype = dtype )
31+ zero_sided = any (side == 0 for side in shape )
32+ if zero_sided :
33+ x = xp .ones (shape , dtype = dtype )
34+ else :
35+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
36+ x = xp .asarray (obj , dtype = dtype )
3337 note (f"{ x = } " )
3438 key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
3539
@@ -62,16 +66,17 @@ def test_getitem(shape, data):
6266 a += 1
6367 out_shape = tuple (out_shape )
6468 ph .assert_shape ("__getitem__" , out .shape , out_shape )
65- assume (all (len (indices ) > 0 for indices in axes_indices ))
66- out_obj = []
67- for idx in product (* axes_indices ):
68- val = obj
69- for i in idx :
70- val = val [i ]
71- out_obj .append (val )
72- out_obj = sh .reshape (out_obj , out_shape )
73- expected = xp .asarray (out_obj , dtype = dtype )
74- ph .assert_array ("__getitem__" , out , expected )
69+ out_zero_sided = any (side == 0 for side in out_shape )
70+ if not zero_sided and not out_zero_sided :
71+ out_obj = []
72+ for idx in product (* axes_indices ):
73+ val = obj
74+ for i in idx :
75+ val = val [i ]
76+ out_obj .append (val )
77+ out_obj = sh .reshape (out_obj , out_shape )
78+ expected = xp .asarray (out_obj , dtype = dtype )
79+ ph .assert_array ("__getitem__" , out , expected )
7580
7681
7782@given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
0 commit comments