Skip to content

Commit b6ad0a4

Browse files
committed
Make test_ExtractDiag_exhaustive less exhaustive
1 parent aaf5a4e commit b6ad0a4

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

tests/link/numba/test_tensor_basic.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616

1717
pytest.importorskip("numba")
18-
from pytensor.link.numba.dispatch import numba_funcify
1918

2019

2120
rng = np.random.default_rng(42849)
@@ -274,28 +273,19 @@ def test_ExtractDiag(val, offset):
274273
)
275274

276275

277-
@pytest.mark.parametrize("k", range(-5, 4))
278-
@pytest.mark.parametrize(
279-
"axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
280-
)
281-
@pytest.mark.parametrize("reverse_axis", (False, True))
282-
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
283-
from pytensor.link.numba.dispatch.basic import numba_njit
284-
285-
if reverse_axis:
286-
axis1, axis2 = axis2, axis1
287-
276+
@pytest.mark.parametrize("k", (-5, -1, 0, 1, 4))
277+
@pytest.mark.parametrize("axis1, axis2", ((0, 1), (0, 3), (1, 2), (2, 1), (2, 3)))
278+
def test_ExtractDiag_exhaustive(k, axis1, axis2):
288279
x = pt.tensor4("x")
289280
x_shape = (2, 3, 4, 5)
290281
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
291282
out = pt.diagonal(x, k, axis1, axis2)
292-
numba_fn = numba_funcify(out.owner.op, out.owner)
293283

294-
@numba_njit(no_cpython_wrapper=False)
295-
def wrap(x):
296-
return numba_fn(x)
297-
298-
np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2))
284+
compare_numba_and_py(
285+
[x],
286+
out,
287+
[x_test],
288+
)
299289

300290

301291
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)