22
33from functools import reduce as _reduce , wraps as _wraps
44from builtins import all as _builtin_all , any as _builtin_any
5- from typing import List , Optional , Sequence , Tuple , Union
5+ from typing import Any , List , Optional , Sequence , Tuple , Union
66
77import torch
88
99from .._internal import get_xp
1010from ..common import _aliases
11+ from ..common ._typing import NestedSequence , SupportsBufferProtocol
1112from ._info import __array_namespace_info__
1213from ._typing import Array , Device , DType
1314
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208remainder = _two_arg (torch .remainder )
208209subtract = _two_arg (torch .subtract )
209210
211+
def asarray(
    obj: (
        Array
        | bool | int | float | complex
        | NestedSequence[bool | int | float | complex]
        | SupportsBufferProtocol
    ),
    /,
    *,
    dtype: DType | None = None,
    device: Device | None = None,
    copy: bool | None = None,
    **kwargs: Any,
) -> Array:
    """Convert ``obj`` to a torch tensor, keeping the source tensor's device.

    Thin wrapper over ``torch.asarray`` working around the fact that it does
    not propagate the input tensor's device to the output when ``device`` is
    omitted (https://github.com/pytorch/pytorch/issues/150199).
    """
    target_device = device
    if target_device is None and isinstance(obj, torch.Tensor):
        # Pin the result to the input's device, as the Array API requires.
        target_device = obj.device
    return torch.asarray(obj, dtype=dtype, device=target_device, copy=copy, **kwargs)
231+
232+
210233# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234# of 'axis'.
212235
@@ -285,7 +308,6 @@ def prod(x: Array,
285308 dtype : Optional [DType ] = None ,
286309 keepdims : bool = False ,
287310 ** kwargs ) -> Array :
288- x = torch .asarray (x )
289311 ndim = x .ndim
290312
291313 # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -321,7 +343,6 @@ def sum(x: Array,
321343 dtype : Optional [DType ] = None ,
322344 keepdims : bool = False ,
323345 ** kwargs ) -> Array :
324- x = torch .asarray (x )
325346 ndim = x .ndim
326347
327348 # https://github.com/pytorch/pytorch/issues/29137.
@@ -351,7 +372,6 @@ def any(x: Array,
351372 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
352373 keepdims : bool = False ,
353374 ** kwargs ) -> Array :
354- x = torch .asarray (x )
355375 ndim = x .ndim
356376 if axis == ():
357377 return x .to (torch .bool )
@@ -376,7 +396,6 @@ def all(x: Array,
376396 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
377397 keepdims : bool = False ,
378398 ** kwargs ) -> Array :
379- x = torch .asarray (x )
380399 ndim = x .ndim
381400 if axis == ():
382401 return x .to (torch .bool )
@@ -819,7 +838,7 @@ def sign(x: Array, /) -> Array:
819838 return out
820839
821840
822- __all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
841+ __all__ = ['__array_namespace_info__' , 'asarray' , ' result_type' , 'can_cast' ,
823842 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
824843 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
825844 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
0 commit comments