@@ -46,29 +46,35 @@ def test_getitem(shape, data):
4646
4747 out = x [key ]
4848
49+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
50+
4951 _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
5052 if Ellipsis in _key :
5153 start_a = _key .index (Ellipsis )
5254 stop_a = start_a + (len (shape ) - (len (_key ) - 1 ))
5355 slices = tuple (slice (None , None ) for _ in range (start_a , stop_a ))
5456 _key = _key [:start_a ] + slices + _key [start_a + 1 :]
5557 axes_indices = []
58+ out_shape = []
5659 for a , i in enumerate (_key ):
5760 if isinstance (i , int ):
5861 axes_indices .append ([i ])
5962 else :
6063 side = shape [a ]
6164 indices = range (side )[i ]
62- assume (len (indices ) > 0 ) # TODO: test 0-sided arrays
6365 axes_indices .append (indices )
64- expected = []
66+ out_shape .append (len (indices ))
67+ out_shape = tuple (out_shape )
68+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
69+ assume (all (len (indices ) > 0 for indices in axes_indices ))
70+ out_obj = []
6571 for idx in product (* axes_indices ):
6672 val = obj
6773 for i in idx :
6874 val = val [i ]
69- expected .append (val )
70- expected = reshape (expected , out . shape )
71- expected = xp .asarray (expected , dtype = dtype )
75+ out_obj .append (val )
76+ out_obj = reshape (out_obj , out_shape )
77+ expected = xp .asarray (out_obj , dtype = dtype )
7278 ph .assert_array ("__getitem__" , out , expected )
7379
7480
0 commit comments