Skip to content

Commit adcf07d

Browse files
committed
Move NonZero Op dispatcher to tensor_basic
1 parent 704d4a3 commit adcf07d

File tree

4 files changed

+32
-28
lines changed

4 files changed

+32
-28
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
)
2525
from pytensor.scalar.basic import ScalarType
2626
from pytensor.sparse import SparseTensorType
27-
from pytensor.tensor.basic import Nonzero
2827
from pytensor.tensor.blas import BatchedDot
2928
from pytensor.tensor.math import Dot
3029
from pytensor.tensor.type import TensorType
@@ -457,15 +456,3 @@ def ifelse(cond, *args):
457456
return res[0]
458457

459458
return ifelse
460-
461-
462-
@numba_funcify.register(Nonzero)
463-
def numba_funcify_Nonzero(op, node, **kwargs):
464-
@numba_njit
465-
def nonzero(a):
466-
result_tuple = np.nonzero(a)
467-
if a.ndim == 1:
468-
return result_tuple[0]
469-
return list(result_tuple)
470-
471-
return nonzero

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import numpy as np
44

55
from pytensor.link.numba.dispatch import basic as numba_basic
6-
from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
6+
from pytensor.link.numba.dispatch.basic import (
7+
create_tuple_string,
8+
numba_funcify,
9+
numba_njit,
10+
)
711
from pytensor.link.utils import compile_function_src, unique_name_generator
812
from pytensor.tensor.basic import (
913
Alloc,
@@ -13,6 +17,7 @@
1317
Eye,
1418
Join,
1519
MakeVector,
20+
Nonzero,
1621
ScalarFromTensor,
1722
Split,
1823
TensorFromScalar,
@@ -235,3 +240,15 @@ def scalar_from_tensor(x):
235240
return numba_basic.to_scalar(x)
236241

237242
return scalar_from_tensor
243+
244+
245+
@numba_funcify.register(Nonzero)
246+
def numba_funcify_Nonzero(op, node, **kwargs):
247+
@numba_njit
248+
def nonzero(a):
249+
result_tuple = np.nonzero(a)
250+
if a.ndim == 1:
251+
return result_tuple[0]
252+
return list(result_tuple)
253+
254+
return nonzero

tests/link/numba/test_basic.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -753,20 +753,6 @@ def test_function_overhead(mode, benchmark):
753753
benchmark(fn, test_x)
754754

755755

756-
@pytest.mark.parametrize(
757-
"input_data",
758-
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
759-
)
760-
def test_Nonzero(input_data):
761-
a = pt.tensor("a", shape=(None,) * input_data.ndim)
762-
763-
graph_outputs = pt.nonzero(a)
764-
765-
compare_numba_and_py(
766-
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
767-
)
768-
769-
770756
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
771757
def test_mat_vec_dot_performance(dtype, benchmark):
772758
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)

tests/link/numba/test_tensor_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype):
326326
g,
327327
[n_test, m_test] if m is not None else [n_test],
328328
)
329+
330+
331+
@pytest.mark.parametrize(
332+
"input_data",
333+
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
334+
)
335+
def test_Nonzero(input_data):
336+
a = pt.tensor("a", shape=(None,) * input_data.ndim)
337+
338+
graph_outputs = pt.nonzero(a)
339+
340+
compare_numba_and_py(
341+
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
342+
)

0 commit comments

Comments
 (0)