Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit 523e7a8

Browse files
committed
fix test
1 parent 72bb96c commit 523e7a8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tensornetwork/tests/ncon_interface_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tensornetwork.ncon_interface import (_get_cont_out_labels,
2020
_canonicalize_network_structure)
2121
from tensornetwork.backends.backend_factory import get_backend
22+
from tensornetwork.backends.jax.jax_backend import JaxBackend
2223
from tensornetwork.contractors import greedy
2324

2425

@@ -57,11 +58,10 @@ def test_return_type(backend):
5758
result_2 = ncon_interface.ncon([n1, n2], [(-1, 1), (1, -2)], backend=backend)
5859
result_3 = ncon_interface.ncon([n1, t2], [(-1, 1), (1, -2)], backend=backend)
5960
assert isinstance(result_2, Tensor)
60-
if backend not in ('jax', get_backend('jax')):
61-
# jitted functions return jaxlib.xla_extension.Buffer,
62-
# convert_to_tensor returns jax.interpreters.xla._DeviceArray now.
63-
assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1)))
64-
assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1)))
61+
if isinstance(backend, JaxBackend) or backend == 'jax':
62+
pytest.skip('return-type tests for jax not implemented')
63+
assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1)))
64+
assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1)))
6565

6666

6767
def test_order_spec(backend):

0 commit comments

Comments
 (0)