Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit 72bb96c

Browse files
committed
fix test after latest releases jax 0.2.9 and jaxlib 0.1.59
1 parent 198937c commit 72bb96c

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensornetwork/tests/ncon_interface_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ def test_return_type(backend):
5656
result_1 = ncon_interface.ncon([t1, t2], [(-1, 1), (1, -2)], backend=backend)
5757
result_2 = ncon_interface.ncon([n1, n2], [(-1, 1), (1, -2)], backend=backend)
5858
result_3 = ncon_interface.ncon([n1, t2], [(-1, 1), (1, -2)], backend=backend)
59-
assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1)))
6059
assert isinstance(result_2, Tensor)
61-
assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1)))
60+
if backend not in ('jax', get_backend('jax')):
61+
# jitted functions return jaxlib.xla_extension.Buffer,
62+
# convert_to_tensor returns jax.interpreters.xla._DeviceArray now.
63+
assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1)))
64+
assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1)))
6265

6366

6467
def test_order_spec(backend):

0 commit comments

Comments
 (0)