@@ -474,6 +474,27 @@ def test_basic(self, xp: ModuleType):
474474 expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]])
475475 xp_assert_equal (actual , expected )
476476
477+ def test_2d (self , xp : ModuleType ):
478+ actual = one_hot (xp .asarray ([[2 , 1 , 0 ], [1 , 0 , 2 ]]), 3 , axis = 1 )
479+ expected = xp .asarray (
480+ [
481+ [[0.0 , 0.0 , 1.0 ], [0.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 ]],
482+ [[0.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]],
483+ ]
484+ )
485+ xp_assert_equal (actual , expected )
486+
487+ @pytest .mark .skip_xp_backend (
488+ Backend .ARRAY_API_STRICTEST , reason = "backend doesn't support Boolean indexing"
489+ )
490+ def test_abstract_size (self , xp : ModuleType ):
491+ x = xp .arange (5 )
492+ x = x [x > 2 ]
493+ x = xp .astype (x , xp .int64 )
494+ actual = one_hot (x , 5 )
495+ expected = xp .asarray ([[0.0 , 0.0 , 0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 0.0 , 1.0 ]])
496+ xp_assert_equal (actual , expected )
497+
477498 @pytest .mark .skip_xp_backend (
478499 Backend .TORCH_GPU , reason = "Puts Pytorch into a bad state."
479500 )
0 commit comments