22
33from functools import reduce as _reduce , wraps as _wraps
44from builtins import all as _builtin_all , any as _builtin_any
5- from typing import List , Optional , Sequence , Tuple , Union
5+ from typing import Any , List , Optional , Sequence , Tuple , Union
66
77import torch
88
99from .._internal import get_xp
1010from ..common import _aliases
11+ from ..common ._typing import NestedSequence , SupportsBufferProtocol
1112from ._info import __array_namespace_info__
1213from ._typing import Array , Device , DType
1314
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208remainder = _two_arg (torch .remainder )
208209subtract = _two_arg (torch .subtract )
209210
211+
212+ def asarray (
213+ obj : (
214+ Array
215+ | bool | int | float | complex
216+ | NestedSequence [bool | int | float | complex ]
217+ | SupportsBufferProtocol
218+ ),
219+ / ,
220+ * ,
221+ dtype : DType | None = None ,
222+ device : Device | None = None ,
223+ copy : bool | None = None ,
224+ ** kwargs : Any ,
225+ ) -> Array :
226+ # torch.asarray does not respect input->output device propagation
227+ # https://github.com/pytorch/pytorch/issues/150199
228+ if device is None and isinstance (obj , torch .Tensor ):
229+ device = obj .device
230+ return torch .asarray (obj , dtype = dtype , device = device , copy = copy , ** kwargs )
231+
232+
210233# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234# of 'axis'.
212235
@@ -282,7 +305,6 @@ def prod(x: Array,
282305 dtype : Optional [DType ] = None ,
283306 keepdims : bool = False ,
284307 ** kwargs ) -> Array :
285- x = torch .asarray (x )
286308 ndim = x .ndim
287309
288310 # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
318340 dtype : Optional [DType ] = None ,
319341 keepdims : bool = False ,
320342 ** kwargs ) -> Array :
321- x = torch .asarray (x )
322343 ndim = x .ndim
323344
324345 # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
348369 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
349370 keepdims : bool = False ,
350371 ** kwargs ) -> Array :
351- x = torch .asarray (x )
352372 ndim = x .ndim
353373 if axis == ():
354374 return x .to (torch .bool )
@@ -373,7 +393,6 @@ def all(x: Array,
373393 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
374394 keepdims : bool = False ,
375395 ** kwargs ) -> Array :
376- x = torch .asarray (x )
377396 ndim = x .ndim
378397 if axis == ():
379398 return x .to (torch .bool )
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
816835 return out
817836
818837
819- __all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
838+ __all__ = ['__array_namespace_info__' , 'asarray' , ' result_type' , 'can_cast' ,
820839 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
821840 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
822841 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
0 commit comments