11from __future__ import annotations
22
3- from typing import Literal
3+ from typing import Callable
44
55from ...common import _aliases , array_namespace
66
3131)
3232
3333from typing import TYPE_CHECKING
34+
3435if TYPE_CHECKING :
3536 from typing import Optional , Union
3637
37- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38+ from ...common ._typing import (
39+ Device ,
40+ Dtype ,
41+ Array ,
42+ NestedSequence ,
43+ SupportsBufferProtocol ,
44+ )
3845
3946import dask .array as da
4047
4148isdtype = get_xp (np )(_aliases .isdtype )
4249unstack = get_xp (da )(_aliases .unstack )
4350
51+
4452# da.astype doesn't respect copy=True
4553def astype (
46- x : Array ,
47- dtype : Dtype ,
48- / ,
49- * ,
50- copy : bool = True ,
51- device : Optional [Device ] = None
54+ x : Array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
5255) -> Array :
5356 """
5457 Array API compatibility wrapper for astype().
@@ -63,8 +66,10 @@ def astype(
6366 x = x .astype (dtype )
6467 return x .copy () if copy else x
6568
69+
6670# Common aliases
6771
72+
6873# This arange func is modified from the common one to
6974# not pass stop/step as keyword arguments, which will cause
7075# an error with dask
@@ -191,6 +196,7 @@ def asarray(
191196 concatenate as concat ,
192197)
193198
199+
194200# dask.array.clip does not work unless all three arguments are provided.
195201# Furthermore, the masking workaround in common._aliases.clip cannot work with
196202# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -207,8 +213,10 @@ def clip(
207213 See the corresponding documentation in the array library and/or the array API
208214 specification for more details.
209215 """
216+
210217 def _isscalar (a ):
211218 return isinstance (a , (int , float , type (None )))
219+
212220 min_shape = () if _isscalar (min ) else min .shape
213221 max_shape = () if _isscalar (max ) else max .shape
214222
@@ -231,6 +239,35 @@ def _isscalar(a):
231239 return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
232240
233241
242+ def _ensure_single_chunk (x : Array , axis : int ) -> tuple [Array , Callable [[Array ], Array ]]:
243+ """
244+ Make sure that Array is not broken into multiple chunks along axis.
245+
246+ Returns
247+ -------
248+ x : Array
249+ The input Array with a single chunk along axis.
250+ restore : Callable[Array, Array]
251+ function to apply to the output to rechunk it back into reasonable chunks
252+ """
253+ if axis < 0 :
254+ axis += x .ndim
255+ if x .numblocks [axis ] < 2 :
256+ return x , lambda x : x
257+
258+ # Break chunks on other axes in an attempt to keep chunk size low
259+ x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
260+
261+ # Rather than reconstructing the original chunks, which can be a
262+ # very expensive affair, just break down oversized chunks without
263+ # incurring in any transfers over the network.
264+ # This has the downside of a risk of overchunking if the array is
265+ # then used in operations against other arrays that match the
266+ # original chunking pattern.
267+ restore = lambda x : x .rechunk ()
268+ return x , restore
269+
270+
234271def sort (
235272 x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
236273) -> Array :
@@ -245,7 +282,20 @@ def sort(
245282 See the corresponding documentation in the array library and/or the array API
246283 specification for more details.
247284 """
248- return _sort_argsort ("sort" , x , axis = axis , descending = descending , stable = stable )
285+ x , restore = _ensure_single_chunk (x , axis )
286+
287+ meta_xp = array_namespace (x ._meta )
288+ x = da .map_blocks (
289+ meta_xp .sort ,
290+ x ,
291+ axis = axis ,
292+ meta = x ._meta ,
293+ dtype = x .dtype ,
294+ descending = descending ,
295+ stable = stable ,
296+ )
297+
298+ return restore (x )
249299
250300
251301def argsort (
@@ -262,49 +312,22 @@ def argsort(
262312 This function temporarily rechunks the array along `axis` into a single chunk.
263313 This can be extremely inefficient and can lead to out-of-memory errors.
264314 """
265- return _sort_argsort ("argsort" , x , axis = axis , descending = descending , stable = stable )
266-
267-
268- def _sort_argsort (
269- func : Literal ["sort" , "argsort" ],
270- x : Array ,
271- / ,
272- * ,
273- axis : int ,
274- descending : bool ,
275- stable : bool ,
276- ) -> Array :
277- """
278- Implementation of sort() and argsort()
315+ x , restore = _ensure_single_chunk (x , axis )
279316
280- TODO Implement sort and argsort properly in Dask on top of the shuffle subsystem.
281- """
282- if axis < 0 :
283- axis += x .ndim
284- rechunk = False
285- if x .numblocks [axis ] > 1 :
286- rechunk = True
287- # Break chunks on other axes in an attempt to keep chunk size low
288- x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
289317 meta_xp = array_namespace (x ._meta )
318+ dtype = meta_xp .argsort (x ._meta ).dtype
319+ meta = meta_xp .astype (x ._meta , dtype )
290320 x = da .map_blocks (
291- getattr ( meta_xp , func ) ,
321+ meta_xp . argsort ,
292322 x ,
293323 axis = axis ,
324+ meta = meta ,
325+ dtype = dtype ,
294326 descending = descending ,
295327 stable = stable ,
296- dtype = x .dtype ,
297- meta = x ._meta ,
298328 )
299- if rechunk :
300- # rather than reconstructing the original chunks, which can be a
301- # very expensive affair, just break down oversized chunks without
302- # incurring in any transfers over the network.
303- # This has the downside of a risk of overchunking if the array is
304- # then used in operations against other arrays that match the
305- # original chunking pattern.
306- x = x .rechunk ()
307- return x
329+
330+ return restore (x )
308331
309332
310333_common_aliases = _aliases .__all__
@@ -318,4 +341,4 @@ def _sort_argsort(
318341 'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
319342 'can_cast' , 'result_type' ]
320343
321- _all_ignore = ["Literal " , "array_namespace" , "get_xp" , "da" , "np" ]
344+ _all_ignore = ["Callable " , "array_namespace" , "get_xp" , "da" , "np" ]
0 commit comments