22
33import cupy as cp
44
5- from ..common import _aliases
5+ from ..common import _aliases , _helpers
66from .._internal import get_xp
77
88from ._info import __array_namespace_info__
4646unique_counts = get_xp (cp )(_aliases .unique_counts )
4747unique_inverse = get_xp (cp )(_aliases .unique_inverse )
4848unique_values = get_xp (cp )(_aliases .unique_values )
49- astype = _aliases .astype
5049std = get_xp (cp )(_aliases .std )
5150var = get_xp (cp )(_aliases .var )
5251cumulative_sum = get_xp (cp )(_aliases .cumulative_sum )
@@ -110,6 +109,21 @@ def asarray(
110109
111110 return cp .array (obj , dtype = dtype , ** kwargs )
112111
112+
113+ def astype (
114+ x : ndarray ,
115+ dtype : Dtype ,
116+ / ,
117+ * ,
118+ copy : bool = True ,
119+ device : Optional [Device ] = None ,
120+ ) -> ndarray :
121+ if device is None :
122+ return x .astype (dtype = dtype , copy = copy )
123+ out = _helpers .to_device (x .astype (dtype = dtype , copy = False ), device )
124+ return out .copy () if copy and out is x else out
125+
126+
113127# These functions are completely new here. If the library already has them
114128# (i.e., numpy 2.0), use the library version instead of our wrapper.
115129if hasattr (cp , 'vecdot' ):
@@ -127,10 +141,10 @@ def asarray(
127141else :
128142 unstack = get_xp (cp )(_aliases .unstack )
129143
130- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'bool ' ,
144+ __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype ' ,
131145 'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
132146 'atan2' , 'atanh' , 'bitwise_left_shift' ,
133147 'bitwise_invert' , 'bitwise_right_shift' ,
134- 'concat' , 'pow' , 'sign' ]
148+ 'bool' , ' concat' , 'pow' , 'sign' ]
135149
136150_all_ignore = ['cp' , 'get_xp' ]
0 commit comments