Skip to content

Commit a08ef5e

Browse files
committed
Cleanup Op dispatchers
1 parent a292c7b commit a08ef5e

File tree

1 file changed

+25
-31
lines changed
  • pytensor/link/numba/dispatch

1 file changed

+25
-31
lines changed

pytensor/link/numba/dispatch/sort.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99

1010
@numba_funcify.register(SortOp)
1111
def numba_funcify_SortOp(op, node, **kwargs):
12+
if op.kind != "quicksort":
13+
warnings.warn(
14+
(
15+
f'Numba function sort doesn\'t support kind="{op.kind}"'
16+
" switching to `quicksort`."
17+
),
18+
UserWarning,
19+
)
20+
1221
@numba_njit
1322
def sort_f(a, axis):
1423
axis = axis.item()
@@ -19,41 +28,11 @@ def sort_f(a, axis):
1928

2029
return a_sorted_swapped
2130

22-
if op.kind != "quicksort":
23-
warnings.warn(
24-
(
25-
f'Numba function sort doesn\'t support kind="{op.kind}"'
26-
" switching to `quicksort`."
27-
),
28-
UserWarning,
29-
)
30-
3131
return sort_f
3232

3333

3434
@numba_funcify.register(ArgSortOp)
3535
def numba_funcify_ArgSortOp(op, node, **kwargs):
36-
def argsort_f_kind(kind):
37-
@numba_njit
38-
def argort_vec(X, axis):
39-
axis = axis.item()
40-
41-
Y = np.swapaxes(X, axis, 0)
42-
result = np.empty_like(Y, dtype="int64")
43-
44-
indices = list(np.ndindex(Y.shape[1:]))
45-
46-
for idx in indices:
47-
result[(slice(None), *idx)] = np.argsort(
48-
Y[(slice(None), *idx)], kind=kind
49-
)
50-
51-
result = np.swapaxes(result, 0, axis)
52-
53-
return result
54-
55-
return argort_vec
56-
5736
kind = op.kind
5837

5938
if kind not in ["quicksort", "mergesort"]:
@@ -66,4 +45,19 @@ def argort_vec(X, axis):
6645
UserWarning,
6746
)
6847

69-
return argsort_f_kind(kind)
48+
@numba_njit
49+
def argort_f(X, axis):
50+
axis = axis.item()
51+
52+
Y = np.swapaxes(X, axis, 0)
53+
result = np.empty_like(Y, dtype="int64")
54+
55+
indices = list(np.ndindex(Y.shape[1:]))
56+
57+
for idx in indices:
58+
result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind)
59+
60+
result = np.swapaxes(result, 0, axis)
61+
return result
62+
63+
return argort_f

0 commit comments

Comments
 (0)