|
15 | 15 |
|
16 | 16 |
|
17 | 17 | pytest.importorskip("numba") |
18 | | -from pytensor.link.numba.dispatch import numba_funcify |
19 | 18 |
|
20 | 19 |
|
21 | 20 | rng = np.random.default_rng(42849) |
@@ -274,28 +273,19 @@ def test_ExtractDiag(val, offset): |
274 | 273 | ) |
275 | 274 |
|
276 | 275 |
|
277 | | -@pytest.mark.parametrize("k", range(-5, 4)) |
278 | | -@pytest.mark.parametrize( |
279 | | - "axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)) |
280 | | -) |
281 | | -@pytest.mark.parametrize("reverse_axis", (False, True)) |
282 | | -def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis): |
283 | | - from pytensor.link.numba.dispatch.basic import numba_njit |
284 | | - |
285 | | - if reverse_axis: |
286 | | - axis1, axis2 = axis2, axis1 |
287 | | - |
| 276 | +@pytest.mark.parametrize("k", (-5, -1, 0, 1, 4)) |
| 277 | +@pytest.mark.parametrize("axis1, axis2", ((0, 1), (0, 3), (1, 2), (2, 1), (2, 3))) |
| 278 | +def test_ExtractDiag_exhaustive(k, axis1, axis2): |
288 | 279 | x = pt.tensor4("x") |
289 | 280 | x_shape = (2, 3, 4, 5) |
290 | 281 | x_test = np.arange(np.prod(x_shape)).reshape(x_shape) |
291 | 282 | out = pt.diagonal(x, k, axis1, axis2) |
292 | | - numba_fn = numba_funcify(out.owner.op, out.owner) |
293 | 283 |
|
294 | | - @numba_njit(no_cpython_wrapper=False) |
295 | | - def wrap(x): |
296 | | - return numba_fn(x) |
297 | | - |
298 | | - np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2)) |
| 284 | + compare_numba_and_py( |
| 285 | + [x], |
| 286 | + out, |
| 287 | + [x_test], |
| 288 | + ) |
299 | 289 |
|
300 | 290 |
|
301 | 291 | @pytest.mark.parametrize( |
|
0 commit comments