Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit 5098980

Browse files
committed
Fix pytorch trace tests
1 parent aae89ae commit 5098980

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tensornetwork/backends/pytorch/pytorch_backend_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -621,16 +621,12 @@ def test_trace(dtype, offset, axis1, axis2):
621621
shape = (5, 5, 5, 5)
622622
backend = pytorch_backend.PyTorchBackend()
623623
array = backend.randn(shape, dtype=dtype, seed=10)
624-
if offset != 0:
625-
with pytest.raises(NotImplementedError):
626-
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)
627-
628-
elif axis1 == axis2:
629-
with pytest.raises(ValueError):
624+
if axis1 == axis2:
625+
with pytest.raises(RuntimeError):
630626
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)
631627
else:
632628
actual = backend.trace(array, offset=offset, axis1=axis1, axis2=axis2)
633-
expected = np.trace(array, axis1=axis1, axis2=axis2)
629+
expected = np.trace(array, offset=offset, axis1=axis1, axis2=axis2)
634630
np.testing.assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)
635631

636632

0 commit comments

Comments
 (0)