1+ import subprocess
2+ import sys
3+
14import numpy as np
25import pytest
36import torch
710
811from ._helpers import import_
912
10-
11- @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" , "dask.array" ])
13+ @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" , "dask.array" , "jax.numpy" ])
1214@pytest .mark .parametrize ("api_version" , [None , "2021.12" ])
1315def test_array_namespace (library , api_version ):
1416 xp = import_ (library )
@@ -21,9 +23,31 @@ def test_array_namespace(library, api_version):
2123 else :
2224 if library == "dask.array" :
2325 assert namespace == array_api_compat .dask .array
26+ elif library == "jax.numpy" :
27+ import jax .experimental .array_api
28+ assert namespace == jax .experimental .array_api
2429 else :
2530 assert namespace == getattr (array_api_compat , library )
2631
32+ # Check that array_namespace works even if jax.experimental.array_api
33+ # hasn't been imported yet (it monkeypatches __array_namespace__
34+ # onto JAX arrays, but we should support them regardless). The only way to
35+ # do this is to use a subprocess, since we cannot un-import it and another
36+ # test probably already imported it.
37+ if library == "jax.numpy" :
38+ code = f"""\
39+ import sys
40+ import jax.numpy
41+ import array_api_compat
42+ array = jax.numpy.asarray([1.0, 2.0, 3.0])
43+
44+ assert 'jax.experimental.array_api' not in sys.modules
45+ namespace = array_api_compat.array_namespace(array, api_version={ api_version !r} )
46+
47+ import jax.experimental.array_api
48+ assert namespace == jax.experimental.array_api
49+ """
50+ subprocess .run ([sys .executable , "-c" , code ], check = True )
2751
2852def test_array_namespace_errors ():
2953 pytest .raises (TypeError , lambda : array_namespace ([1 ]))
0 commit comments