We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bb3e4ea commit 507afa4Copy full SHA for 507afa4
src/array_api_extra/_delegation.py
@@ -169,7 +169,7 @@ def one_hot(
169
msg = "x must have an integral dtype."
170
raise TypeError(msg)
171
if dtype is None:
172
- dtype = xp.empty(()).dtype # Default float dtype
+ dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))["real floating"]
173
# Delegate where possible.
174
if is_jax_namespace(xp):
175
assert is_jax_array(x)
0 commit comments