@@ -426,6 +426,22 @@ def test_tile(x, data):
426426def test_unstack (x , data ):
427427 axis = data .draw (st .integers (min_value = - x .ndim , max_value = x .ndim - 1 ), label = "axis" )
428428 kw = data .draw (hh .specified_kwargs (("axis" , axis , 0 )), label = "kw" )
429- out = xp .asarray (xp .unstack (x , ** kw ), dtype = x .dtype )
430- ph .assert_dtype ("unstack" , in_dtype = x .dtype , out_dtype = out .dtype )
431- # TODO: shapes and values testing
429+ out = xp .unstack (x , ** kw )
430+
431+ assert isinstance (out , tuple )
432+ assert len (out ) == x .shape [axis ]
433+ expected_shape = list (x .shape )
434+ expected_shape .pop (axis )
435+ expected_shape = tuple (expected_shape )
436+ for i in range (x .shape [axis ]):
437+ arr = out [i ]
438+ ph .assert_result_shape ("unstack" , in_shapes = [x .shape ],
439+ out_shape = arr .shape , expected = expected_shape ,
440+ kw = kw , repr_name = f"out[{ i } ].shape" )
441+
442+ ph .assert_dtype ("unstack" , in_dtype = x .dtype , out_dtype = arr .dtype ,
443+ repr_name = f"out[{ i } ].dtype" )
444+
445+ idx = [slice (None )] * x .ndim
446+ idx [axis ] = i
447+ ph .assert_array_elements ("unstack" , out = arr , expected = x [tuple (idx )], kw = kw , out_repr = f"out[{ i } ]" )
0 commit comments