@@ -126,13 +126,13 @@ def one_hot(
126126 """
127127 One-hot encode the given indices.
128128
129- Each index in the input ``x`` is encoded as a vector of zeros of length
130- ``num_classes`` with the element at the given index set to one.
129+ Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
130+ with the element at the given index set to one.
131131
132132 Parameters
133133 ----------
134134 x : array
135- An array with integral dtype having shape ``batch_dims`` .
135+ An array with integral dtype and concrete size (``x.size`` cannot be `None`) .
136136 num_classes : int
137137 Number of classes in the one-hot dimension.
138138 dtype : DType, optional
@@ -147,17 +147,20 @@ def one_hot(
147147 -------
148148 array
149149 An array having the same shape as `x` except for a new axis at the position
150- given by `axis` having size `num_classes`.
150+ given by `axis` having size `num_classes`. If `axis` is unspecified, it
151+ defaults to -1, which appends a new axis.
151152
152153 If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
153154 an exception, or may even cause a bad state. `x` is not checked.
154155
155156 Examples
156157 --------
157- >>> xp.one_hot(jnp.array([1, 2, 0]), 3)
158+ >>> import array_api_extra as xpx
159+ >>> import array-api-strict as xp
160+ >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
158161 Array([[0., 1., 0.],
159- [0., 0., 1.],
160- [1., 0., 0.]], dtype=float64)
162+ [0., 0., 1.],
163+ [1., 0., 0.]], dtype=array_api_strict. float64)
161164 """
162165 # Validate inputs.
163166 if xp is None :
0 commit comments