|
19 | 19 | from tensornetwork.ncon_interface import (_get_cont_out_labels, |
20 | 20 | _canonicalize_network_structure) |
21 | 21 | from tensornetwork.backends.backend_factory import get_backend |
| 22 | +from tensornetwork.backends.jax.jax_backend import JaxBackend |
22 | 23 | from tensornetwork.contractors import greedy |
23 | 24 |
|
24 | 25 |
|
@@ -57,11 +58,10 @@ def test_return_type(backend): |
57 | 58 | result_2 = ncon_interface.ncon([n1, n2], [(-1, 1), (1, -2)], backend=backend) |
58 | 59 | result_3 = ncon_interface.ncon([n1, t2], [(-1, 1), (1, -2)], backend=backend) |
59 | 60 | assert isinstance(result_2, Tensor) |
60 | | - if backend not in ('jax', get_backend('jax')): |
61 | | - # jitted functions return jaxlib.xla_extension.Buffer, |
62 | | - # convert_to_tensor returns jax.interpreters.xla._DeviceArray now. |
63 | | - assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1))) |
64 | | - assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1))) |
| 61 | + if isinstance(backend, JaxBackend) or backend == 'jax': |
| 62 | + pytest.skip('return-type tests for jax not implemented') |
| 63 | + assert isinstance(result_1, type(n1.backend.convert_to_tensor(t1))) |
| 64 | + assert isinstance(result_3, type(n1.backend.convert_to_tensor(t1))) |
65 | 65 |
|
66 | 66 |
|
67 | 67 | def test_order_spec(backend): |
|
0 commit comments