@@ -332,17 +332,19 @@ def argsort(
332332 ** kwargs ,
333333) -> ndarray :
334334 # Note: this keyword argument is different, and the default is different.
335- kind = "stable" if stable else "quicksort"
335+ # We set it in kwargs like this because numpy.sort uses kind='quicksort'
336+ # as the default whereas cupy.sort uses kind=None.
337+ if stable :
338+ kwargs ['kind' ] = "stable"
336339 if not descending :
337- res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
340+ res = xp .argsort (x , axis = axis , ** kwargs )
338341 else :
339342 # As NumPy has no native descending sort, we imitate it here. Note that
340343 # simply flipping the results of xp.argsort(x, ...) would not
341344 # respect the relative order like it would in native descending sorts.
342345 res = xp .flip (
343- xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
346+ xp .argsort (xp .flip (x , axis = axis ), axis = axis , ** kwargs ),
344347 axis = axis ,
345- ** kwargs ,
346348 )
347349 # Rely on flip()/argsort() to validate axis
348350 normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -355,8 +357,11 @@ def sort(
355357 ** kwargs ,
356358) -> ndarray :
357359 # Note: this keyword argument is different, and the default is different.
358- kind = "stable" if stable else "quicksort"
359- res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
360+ # We set it in kwargs like this because numpy.sort uses kind='quicksort'
361+ # as the default whereas cupy.sort uses kind=None.
362+ if stable :
363+ kwargs ['kind' ] = "stable"
364+ res = xp .sort (x , axis = axis , ** kwargs )
360365 if descending :
361366 res = xp .flip (res , axis = axis )
362367 return res
0 commit comments