11from __future__ import annotations
22
33from ...common import _aliases
4- from ...common ._helpers import _check_device
54
65from ..._internal import get_xp
76
4039isdtype = get_xp (np )(_aliases .isdtype )
4140unstack = get_xp (da )(_aliases .unstack )
4241
42+ # da.astype doesn't respect copy=True
4343def astype (
4444 x : Array ,
4545 dtype : Dtype ,
4646 / ,
4747 * ,
4848 copy : bool = True ,
49- device : Device | None = None
49+ device : Optional [ Device ] = None
5050) -> Array :
51+ """
52+ Array API compatibility wrapper for astype().
53+
54+ See the corresponding documentation in the array library and/or the array API
55+ specification for more details.
56+ """
5157 # TODO: respect device keyword?
58+
5259 if not copy and dtype == x .dtype :
5360 return x
54- # dask astype doesn't respect copy=True,
55- # so call copy manually afterwards
5661 x = x .astype (dtype )
5762 return x .copy () if copy else x
5863
@@ -61,20 +66,24 @@ def astype(
6166# This arange func is modified from the common one to
6267# not pass stop/step as keyword arguments, which will cause
6368# an error with dask
64-
65- # TODO: delete the xp stuff, it shouldn't be necessary
66- def _dask_arange (
69+ def arange (
6770 start : Union [int , float ],
6871 / ,
6972 stop : Optional [Union [int , float ]] = None ,
7073 step : Union [int , float ] = 1 ,
7174 * ,
72- xp ,
7375 dtype : Optional [Dtype ] = None ,
7476 device : Optional [Device ] = None ,
7577 ** kwargs ,
7678) -> Array :
77- _check_device (xp , device )
79+ """
80+ Array API compatibility wrapper for arange().
81+
82+ See the corresponding documentation in the array library and/or the array API
83+ specification for more details.
84+ """
85+ # TODO: respect device keyword?
86+
7887 args = [start ]
7988 if stop is not None :
8089 args .append (stop )
@@ -83,13 +92,12 @@ def _dask_arange(
8392 # prepend the default value for start which is 0
8493 args .insert (0 , 0 )
8594 args .append (step )
86- return xp .arange (* args , dtype = dtype , ** kwargs )
8795
88- arange = get_xp ( da )( _dask_arange )
89- eye = get_xp ( da )( _aliases . eye )
96+ return da . arange ( * args , dtype = dtype , ** kwargs )
97+
9098
91- linspace = get_xp (da )(_aliases .linspace )
9299eye = get_xp (da )(_aliases .eye )
100+ linspace = get_xp (da )(_aliases .linspace )
93101UniqueAllResult = get_xp (da )(_aliases .UniqueAllResult )
94102UniqueCountsResult = get_xp (da )(_aliases .UniqueCountsResult )
95103UniqueInverseResult = get_xp (da )(_aliases .UniqueInverseResult )
@@ -112,7 +120,6 @@ def _dask_arange(
112120reshape = get_xp (da )(_aliases .reshape )
113121matrix_transpose = get_xp (da )(_aliases .matrix_transpose )
114122vecdot = get_xp (da )(_aliases .vecdot )
115-
116123nonzero = get_xp (da )(_aliases .nonzero )
117124ceil = get_xp (np )(_aliases .ceil )
118125floor = get_xp (np )(_aliases .floor )
@@ -121,6 +128,7 @@ def _dask_arange(
121128tensordot = get_xp (np )(_aliases .tensordot )
122129sign = get_xp (np )(_aliases .sign )
123130
131+
124132# asarray also adds the copy keyword, which is not present in numpy 1.0.
125133def asarray (
126134 obj : Union [
@@ -135,7 +143,7 @@ def asarray(
135143 * ,
136144 dtype : Optional [Dtype ] = None ,
137145 device : Optional [Device ] = None ,
138- copy : " Optional[Union[bool, np._CopyMode]]" = None ,
146+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
139147 ** kwargs ,
140148) -> Array :
141149 """
@@ -144,6 +152,8 @@ def asarray(
144152 See the corresponding documentation in the array library and/or the array API
145153 specification for more details.
146154 """
155+ # TODO: respect device keyword?
156+
147157 if isinstance (obj , da .Array ):
148158 if dtype is not None and dtype != obj .dtype :
149159 if copy is False :
@@ -183,38 +193,40 @@ def asarray(
183193# Furthermore, the masking workaround in common._aliases.clip cannot work with
184194# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185195# now).
186- @get_xp (da )
187196def clip (
188197 x : Array ,
189198 / ,
190199 min : Optional [Union [int , float , Array ]] = None ,
191200 max : Optional [Union [int , float , Array ]] = None ,
192- * ,
193- xp ,
194201) -> Array :
202+ """
203+ Array API compatibility wrapper for clip().
204+
205+ See the corresponding documentation in the array library and/or the array API
206+ specification for more details.
207+ """
195208 def _isscalar (a ):
196209 return isinstance (a , (int , float , type (None )))
197210 min_shape = () if _isscalar (min ) else min .shape
198211 max_shape = () if _isscalar (max ) else max .shape
199212
200213 # TODO: This won't handle dask unknown shapes
201- import numpy as np
202214 result_shape = np .broadcast_shapes (x .shape , min_shape , max_shape )
203215
204216 if min is not None :
205- min = xp .broadcast_to (xp .asarray (min ), result_shape )
217+ min = da .broadcast_to (da .asarray (min ), result_shape )
206218 if max is not None :
207- max = xp .broadcast_to (xp .asarray (max ), result_shape )
219+ max = da .broadcast_to (da .asarray (max ), result_shape )
208220
209221 if min is None and max is None :
210- return xp .positive (x )
222+ return da .positive (x )
211223
212224 if min is None :
213- return astype (xp .minimum (x , max ), x .dtype )
225+ return astype (da .minimum (x , max ), x .dtype )
214226 if max is None :
215- return astype (xp .maximum (x , min ), x .dtype )
227+ return astype (da .maximum (x , min ), x .dtype )
216228
217- return astype (xp .minimum (xp .maximum (x , min ), max ), x .dtype )
229+ return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
218230
219231# exclude these from all since dask.array has no sorting functions
220232_da_unsupported = ['sort' , 'argsort' ]
0 commit comments