Skip to content

Commit 704d4a3

Browse files
committed
Move numba sort Ops dispatchers to their own file
1 parent afe712e commit 704d4a3

File tree

5 files changed

+142
-128
lines changed

5 files changed

+142
-128
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pytensor.link.numba.dispatch.shape
1313
import pytensor.link.numba.dispatch.signal
1414
import pytensor.link.numba.dispatch.slinalg
15+
import pytensor.link.numba.dispatch.sort
1516
import pytensor.link.numba.dispatch.sparse
1617
import pytensor.link.numba.dispatch.subtensor
1718
import pytensor.link.numba.dispatch.tensor_basic

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pytensor.tensor.basic import Nonzero
2828
from pytensor.tensor.blas import BatchedDot
2929
from pytensor.tensor.math import Dot
30-
from pytensor.tensor.sort import ArgSortOp, SortOp
3130
from pytensor.tensor.type import TensorType
3231

3332

@@ -317,68 +316,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
317316
return deepcopyop
318317

319318

320-
@numba_funcify.register(SortOp)
321-
def numba_funcify_SortOp(op, node, **kwargs):
322-
@numba_njit
323-
def sort_f(a, axis):
324-
axis = axis.item()
325-
326-
a_swapped = np.swapaxes(a, axis, -1)
327-
a_sorted = np.sort(a_swapped)
328-
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
329-
330-
return a_sorted_swapped
331-
332-
if op.kind != "quicksort":
333-
warnings.warn(
334-
(
335-
f'Numba function sort doesn\'t support kind="{op.kind}"'
336-
" switching to `quicksort`."
337-
),
338-
UserWarning,
339-
)
340-
341-
return sort_f
342-
343-
344-
@numba_funcify.register(ArgSortOp)
345-
def numba_funcify_ArgSortOp(op, node, **kwargs):
346-
def argsort_f_kind(kind):
347-
@numba_njit
348-
def argort_vec(X, axis):
349-
axis = axis.item()
350-
351-
Y = np.swapaxes(X, axis, 0)
352-
result = np.empty_like(Y, dtype="int64")
353-
354-
indices = list(np.ndindex(Y.shape[1:]))
355-
356-
for idx in indices:
357-
result[(slice(None), *idx)] = np.argsort(
358-
Y[(slice(None), *idx)], kind=kind
359-
)
360-
361-
result = np.swapaxes(result, 0, axis)
362-
363-
return result
364-
365-
return argort_vec
366-
367-
kind = op.kind
368-
369-
if kind not in ["quicksort", "mergesort"]:
370-
kind = "quicksort"
371-
warnings.warn(
372-
(
373-
f'Numba function argsort doesn\'t support kind="{op.kind}"'
374-
" switching to `quicksort`."
375-
),
376-
UserWarning,
377-
)
378-
379-
return argsort_f_kind(kind)
380-
381-
382319
@numba.extending.intrinsic
383320
def direct_cast(typingctx, val, typ):
384321
if isinstance(typ, numba.types.TypeRef):
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import warnings
2+
3+
import numpy as np
4+
5+
from pytensor.link.numba.dispatch import numba_funcify
6+
from pytensor.link.numba.dispatch.basic import numba_njit
7+
from pytensor.tensor.sort import ArgSortOp, SortOp
8+
9+
10+
@numba_funcify.register(SortOp)
11+
def numba_funcify_SortOp(op, node, **kwargs):
12+
@numba_njit
13+
def sort_f(a, axis):
14+
axis = axis.item()
15+
16+
a_swapped = np.swapaxes(a, axis, -1)
17+
a_sorted = np.sort(a_swapped)
18+
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
19+
20+
return a_sorted_swapped
21+
22+
if op.kind != "quicksort":
23+
warnings.warn(
24+
(
25+
f'Numba function sort doesn\'t support kind="{op.kind}"'
26+
" switching to `quicksort`."
27+
),
28+
UserWarning,
29+
)
30+
31+
return sort_f
32+
33+
34+
@numba_funcify.register(ArgSortOp)
35+
def numba_funcify_ArgSortOp(op, node, **kwargs):
36+
def argsort_f_kind(kind):
37+
@numba_njit
38+
def argort_vec(X, axis):
39+
axis = axis.item()
40+
41+
Y = np.swapaxes(X, axis, 0)
42+
result = np.empty_like(Y, dtype="int64")
43+
44+
indices = list(np.ndindex(Y.shape[1:]))
45+
46+
for idx in indices:
47+
result[(slice(None), *idx)] = np.argsort(
48+
Y[(slice(None), *idx)], kind=kind
49+
)
50+
51+
result = np.swapaxes(result, 0, axis)
52+
53+
return result
54+
55+
return argort_vec
56+
57+
kind = op.kind
58+
59+
if kind not in ["quicksort", "mergesort"]:
60+
kind = "quicksort"
61+
warnings.warn(
62+
(
63+
f'Numba function argsort doesn\'t support kind="{op.kind}"'
64+
" switching to `quicksort`."
65+
),
66+
UserWarning,
67+
)
68+
69+
return argsort_f_kind(kind)

tests/link/numba/test_basic.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from pytensor.scalar.basic import ScalarOp, as_scalar
3333
from pytensor.tensor import blas, tensor
3434
from pytensor.tensor.elemwise import Elemwise
35-
from pytensor.tensor.sort import ArgSortOp, SortOp
3635

3736

3837
if TYPE_CHECKING:
@@ -356,70 +355,6 @@ def test_create_numba_signature(v, expected, force_scalar):
356355
assert res == expected
357356

358357

359-
@pytest.mark.parametrize(
360-
"x",
361-
[
362-
[], # Empty list
363-
[3, 2, 1], # Simple list
364-
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
365-
],
366-
)
367-
@pytest.mark.parametrize("axis", [0, -1, None])
368-
@pytest.mark.parametrize(
369-
("kind", "exc"),
370-
[
371-
["quicksort", None],
372-
["mergesort", UserWarning],
373-
["heapsort", UserWarning],
374-
["stable", UserWarning],
375-
],
376-
)
377-
def test_Sort(x, axis, kind, exc):
378-
if axis:
379-
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
380-
else:
381-
g = SortOp(kind)(pt.as_tensor_variable(x))
382-
383-
cm = contextlib.suppress() if not exc else pytest.warns(exc)
384-
385-
with cm:
386-
compare_numba_and_py([], [g], [])
387-
388-
389-
@pytest.mark.parametrize(
390-
"x",
391-
[
392-
[], # Empty list
393-
[3, 2, 1], # Simple list
394-
None, # Multi-dimensional array (see below)
395-
],
396-
)
397-
@pytest.mark.parametrize("axis", [0, -1, None])
398-
@pytest.mark.parametrize(
399-
("kind", "exc"),
400-
[
401-
["quicksort", None],
402-
["heapsort", None],
403-
["stable", UserWarning],
404-
],
405-
)
406-
def test_ArgSort(x, axis, kind, exc):
407-
if x is None:
408-
x = np.arange(5 * 5 * 5 * 5)
409-
np.random.shuffle(x)
410-
x = np.reshape(x, (5, 5, 5, 5))
411-
412-
if axis:
413-
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
414-
else:
415-
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
416-
417-
cm = contextlib.suppress() if not exc else pytest.warns(exc)
418-
419-
with cm:
420-
compare_numba_and_py([], [g], [])
421-
422-
423358
def test_ViewOp():
424359
v = pt.vector()
425360
v_test_value = np.arange(4, dtype=config.floatX)

tests/link/numba/test_sort.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import contextlib
2+
3+
import numpy as np
4+
import pytest
5+
from link.numba.test_basic import compare_numba_and_py
6+
7+
from pytensor import tensor as pt
8+
from pytensor.tensor.sort import ArgSortOp, SortOp
9+
10+
11+
@pytest.mark.parametrize(
12+
"x",
13+
[
14+
[], # Empty list
15+
[3, 2, 1], # Simple list
16+
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
17+
],
18+
)
19+
@pytest.mark.parametrize("axis", [0, -1, None])
20+
@pytest.mark.parametrize(
21+
("kind", "exc"),
22+
[
23+
["quicksort", None],
24+
["mergesort", UserWarning],
25+
["heapsort", UserWarning],
26+
["stable", UserWarning],
27+
],
28+
)
29+
def test_Sort(x, axis, kind, exc):
30+
if axis:
31+
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
32+
else:
33+
g = SortOp(kind)(pt.as_tensor_variable(x))
34+
35+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
36+
37+
with cm:
38+
compare_numba_and_py([], [g], [])
39+
40+
41+
@pytest.mark.parametrize(
42+
"x",
43+
[
44+
[], # Empty list
45+
[3, 2, 1], # Simple list
46+
None, # Multi-dimensional array (see below)
47+
],
48+
)
49+
@pytest.mark.parametrize("axis", [0, -1, None])
50+
@pytest.mark.parametrize(
51+
("kind", "exc"),
52+
[
53+
["quicksort", None],
54+
["heapsort", None],
55+
["stable", UserWarning],
56+
],
57+
)
58+
def test_ArgSort(x, axis, kind, exc):
59+
if x is None:
60+
x = np.arange(5 * 5 * 5 * 5)
61+
np.random.shuffle(x)
62+
x = np.reshape(x, (5, 5, 5, 5))
63+
64+
if axis:
65+
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
66+
else:
67+
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
68+
69+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
70+
71+
with cm:
72+
compare_numba_and_py([], [g], [])

0 commit comments

Comments
 (0)