@@ -380,8 +380,6 @@ def one_hot(
380380 / ,
381381 num_classes : int ,
382382 * ,
383- supports_fancy_indexing : bool = False ,
384- supports_array_indexing : bool = False ,
385383 dtype : DType ,
386384 xp : ModuleType ,
387385) -> Array : # numpydoc ignore=PR01,RT01
@@ -394,19 +392,16 @@ def one_hot(
394392 # specification.
395393 msg = "x must have a concrete size."
396394 raise TypeError (msg )
397- out = xp .zeros ((x .size , num_classes ), dtype = dtype , device = _compat .device (x ))
398- x_flattened = xp .reshape (x , (- 1 ,))
399- if supports_fancy_indexing :
400- out = at (out )[xp .arange (x_size ), x_flattened ].set (1 )
401- else :
402- for i in range (x_size ):
403- x_i = x_flattened [i ]
404- if not supports_array_indexing :
405- x_i = int (x_i )
406- out = at (out )[i , x_i ].set (1 )
407- if x .ndim != 1 :
408- out = xp .reshape (out , (* x .shape , num_classes ))
409- return out
395+ # TODO: Benchmark whether this is faster on the numpy backend:
396+ # x_flattened = xp.reshape(x, (-1,))
397+ # out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398+ # at(out)[xp.arange(x_size), x_flattened].set(1)
399+ # if x.ndim != 1:
400+ # out = xp.reshape(out, (*x.shape, num_classes))
401+ out = x [..., None ] == xp .arange (
402+ num_classes , dtype = x .dtype , device = _compat .device (x )
403+ )
404+ return xp .astype (out , dtype )
410405
411406
412407def create_diagonal (
0 commit comments