99
1010@numba_funcify .register (SortOp )
1111def numba_funcify_SortOp (op , node , ** kwargs ):
12+ if op .kind != "quicksort" :
13+ warnings .warn (
14+ (
15+ f'Numba function sort doesn\' t support kind="{ op .kind } "'
16+ " switching to `quicksort`."
17+ ),
18+ UserWarning ,
19+ )
20+
1221 @numba_njit
1322 def sort_f (a , axis ):
1423 axis = axis .item ()
@@ -19,41 +28,11 @@ def sort_f(a, axis):
1928
2029 return a_sorted_swapped
2130
22- if op .kind != "quicksort" :
23- warnings .warn (
24- (
25- f'Numba function sort doesn\' t support kind="{ op .kind } "'
26- " switching to `quicksort`."
27- ),
28- UserWarning ,
29- )
30-
3131 return sort_f
3232
3333
3434@numba_funcify .register (ArgSortOp )
3535def numba_funcify_ArgSortOp (op , node , ** kwargs ):
36- def argsort_f_kind (kind ):
37- @numba_njit
38- def argort_vec (X , axis ):
39- axis = axis .item ()
40-
41- Y = np .swapaxes (X , axis , 0 )
42- result = np .empty_like (Y , dtype = "int64" )
43-
44- indices = list (np .ndindex (Y .shape [1 :]))
45-
46- for idx in indices :
47- result [(slice (None ), * idx )] = np .argsort (
48- Y [(slice (None ), * idx )], kind = kind
49- )
50-
51- result = np .swapaxes (result , 0 , axis )
52-
53- return result
54-
55- return argort_vec
56-
5736 kind = op .kind
5837
5938 if kind not in ["quicksort" , "mergesort" ]:
@@ -66,4 +45,19 @@ def argort_vec(X, axis):
6645 UserWarning ,
6746 )
6847
69- return argsort_f_kind (kind )
48+ @numba_njit
49+ def argort_f (X , axis ):
50+ axis = axis .item ()
51+
52+ Y = np .swapaxes (X , axis , 0 )
53+ result = np .empty_like (Y , dtype = "int64" )
54+
55+ indices = list (np .ndindex (Y .shape [1 :]))
56+
57+ for idx in indices :
58+ result [(slice (None ), * idx )] = np .argsort (Y [(slice (None ), * idx )], kind = kind )
59+
60+ result = np .swapaxes (result , 0 , axis )
61+ return result
62+
63+ return argort_f
0 commit comments