1111
1212if TYPE_CHECKING :
1313 from typing import Optional , Union , Any
14- from ._typing import Array , Device
14+ from ._typing import Array , Device
1515
1616import sys
1717import math
18+ import inspect
1819
19- def _is_numpy_array (x ):
20+ def is_numpy_array (x ):
2021 # Avoid importing NumPy if it isn't already
2122 if 'numpy' not in sys .modules :
2223 return False
@@ -26,7 +27,7 @@ def _is_numpy_array(x):
2627 # TODO: Should we reject ndarray subclasses?
2728 return isinstance (x , (np .ndarray , np .generic ))
2829
29- def _is_cupy_array (x ):
30+ def is_cupy_array (x ):
3031 # Avoid importing NumPy if it isn't already
3132 if 'cupy' not in sys .modules :
3233 return False
@@ -36,7 +37,7 @@ def _is_cupy_array(x):
3637 # TODO: Should we reject ndarray subclasses?
3738 return isinstance (x , (cp .ndarray , cp .generic ))
3839
39- def _is_torch_array (x ):
40+ def is_torch_array (x ):
4041 # Avoid importing torch if it isn't already
4142 if 'torch' not in sys .modules :
4243 return False
@@ -46,7 +47,7 @@ def _is_torch_array(x):
4647 # TODO: Should we reject ndarray subclasses?
4748 return isinstance (x , torch .Tensor )
4849
49- def _is_dask_array (x ):
50+ def is_dask_array (x ):
5051 # Avoid importing dask if it isn't already
5152 if 'dask.array' not in sys .modules :
5253 return False
@@ -55,14 +56,24 @@ def _is_dask_array(x):
5556
5657 return isinstance (x , dask .array .Array )
5758
59+ def is_jax_array (x ):
60+ # Avoid importing jax if it isn't already
61+ if 'jax' not in sys .modules :
62+ return False
63+
64+ import jax
65+
66+ return isinstance (x , jax .Array )
67+
5868def is_array_api_obj (x ):
5969 """
6070 Check if x is an array API compatible array object.
6171 """
62- return _is_numpy_array (x ) \
63- or _is_cupy_array (x ) \
64- or _is_torch_array (x ) \
65- or _is_dask_array (x ) \
72+ return is_numpy_array (x ) \
73+ or is_cupy_array (x ) \
74+ or is_torch_array (x ) \
75+ or is_dask_array (x ) \
76+ or is_jax_array (x ) \
6677 or hasattr (x , '__array_namespace__' )
6778
6879def _check_api_version (api_version ):
@@ -87,37 +98,43 @@ def your_function(x, y):
8798 """
8899 namespaces = set ()
89100 for x in xs :
90- if _is_numpy_array (x ):
101+ if is_numpy_array (x ):
91102 _check_api_version (api_version )
92103 if _use_compat :
93104 from .. import numpy as numpy_namespace
94105 namespaces .add (numpy_namespace )
95106 else :
96107 import numpy as np
97108 namespaces .add (np )
98- elif _is_cupy_array (x ):
109+ elif is_cupy_array (x ):
99110 _check_api_version (api_version )
100111 if _use_compat :
101112 from .. import cupy as cupy_namespace
102113 namespaces .add (cupy_namespace )
103114 else :
104115 import cupy as cp
105116 namespaces .add (cp )
106- elif _is_torch_array (x ):
117+ elif is_torch_array (x ):
107118 _check_api_version (api_version )
108119 if _use_compat :
109120 from .. import torch as torch_namespace
110121 namespaces .add (torch_namespace )
111122 else :
112123 import torch
113124 namespaces .add (torch )
114- elif _is_dask_array (x ):
125+ elif is_dask_array (x ):
115126 _check_api_version (api_version )
116127 if _use_compat :
117128 from ..dask import array as dask_namespace
118129 namespaces .add (dask_namespace )
119130 else :
120131 raise TypeError ("_use_compat cannot be False if input array is a dask array!" )
132+ elif is_jax_array (x ):
133+ _check_api_version (api_version )
134+ # jax.experimental.array_api is already an array namespace. We do
135+ # not have a wrapper submodule for it.
136+ import jax .experimental .array_api as jnp
137+ namespaces .add (jnp )
121138 elif hasattr (x , '__array_namespace__' ):
122139 namespaces .add (x .__array_namespace__ (api_version = api_version ))
123140 else :
@@ -142,7 +159,7 @@ def _check_device(xp, device):
142159 if device not in ["cpu" , None ]:
143160 raise ValueError (f"Unsupported device for NumPy: { device !r} " )
144161
145- # device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
162+ # device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
146163# or cupy.ndarray. They are not included in array objects of this library
147164# because this library just reuses the respective ndarray classes without
148165# wrapping or subclassing them. These helper functions can be used instead of
@@ -162,8 +179,17 @@ def device(x: Array, /) -> Device:
162179 out: device
163180 a ``device`` object (see the "Device Support" section of the array API specification).
164181 """
165- if _is_numpy_array (x ):
182+ if is_numpy_array (x ):
166183 return "cpu"
184+ if is_jax_array (x ):
185+ # JAX has .device() as a method, but it is being deprecated so that it
186+ # can become a property, in accordance with the standard. In order for
187+ # this function to not break when JAX makes the flip, we check for
188+ # both here.
189+ if inspect .ismethod (x .device ):
190+ return x .device ()
191+ else :
192+ return x .device
167193 return x .device
168194
169195# Based on cupy.array_api.Array.to_device
@@ -231,24 +257,28 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
231257 .. note::
232258 If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
233259 """
234- if _is_numpy_array (x ):
260+ if is_numpy_array (x ):
235261 if stream is not None :
236262 raise ValueError ("The stream argument to to_device() is not supported" )
237263 if device == 'cpu' :
238264 return x
239265 raise ValueError (f"Unsupported device { device !r} " )
240- elif _is_cupy_array (x ):
266+ elif is_cupy_array (x ):
241267 # cupy does not yet have to_device
242268 return _cupy_to_device (x , device , stream = stream )
243- elif _is_torch_array (x ):
269+ elif is_torch_array (x ):
244270 return _torch_to_device (x , device , stream = stream )
245- elif _is_dask_array (x ):
271+ elif is_dask_array (x ):
246272 if stream is not None :
247273 raise ValueError ("The stream argument to to_device() is not supported" )
248274 # TODO: What if our array is on the GPU already?
249275 if device == 'cpu' :
250276 return x
251277 raise ValueError (f"Unsupported device { device !r} " )
278+ elif is_jax_array (x ):
279+ # This import adds to_device to x
280+ import jax .experimental .array_api # noqa: F401
281+ return x .to_device (device , stream = stream )
252282 return x .to_device (device , stream = stream )
253283
254284def size (x ):
0 commit comments