@@ -178,7 +178,7 @@ def _check_api_version(api_version):
178178 elif api_version is not None and api_version != '2022.12' :
179179 raise ValueError ("Only the 2022.12 version of the array API specification is currently supported" )
180180
181- def array_namespace (* xs , api_version = None , _use_compat = True ):
181+ def array_namespace (* xs , api_version = None , use_compat = None ):
182182 """
183183 Get the array API compatible namespace for the arrays `xs`.
184184
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
191191 The newest version of the spec that you need support for (currently
192192 the compat library wrapped APIs support v2022.12).
193193
194+ use_compat: bool or None
195+ If None (the default), the native namespace will be returned if it is
196+ already array API compatible, otherwise a compat wrapper is used. If
197+ True, the compat library wrapped library will be returned. If False,
198+ the native library namespace is returned.
199+
194200 Returns
195201 -------
196202
@@ -234,46 +240,66 @@ def your_function(x, y):
234240 is_jax_array
235241
236242 """
243+ if use_compat not in [None , True , False ]:
244+ raise ValueError ("use_compat must be None, True, or False" )
245+
246+ _use_compat = use_compat in [None , True ]
247+
237248 namespaces = set ()
238249 for x in xs :
239250 if is_numpy_array (x ):
240- _check_api_version (api_version )
241- if _use_compat :
242- from .. import numpy as numpy_namespace
251+ from .. import numpy as numpy_namespace
252+ import numpy as np
253+ if use_compat is True :
254+ _check_api_version (api_version )
243255 namespaces .add (numpy_namespace )
244- else :
245- import numpy as np
256+ elif use_compat is False :
246257 namespaces .add (np )
258+ else :
259+ # numpy 2.0 has __array_namespace__ and is fully array API
260+ # compatible.
261+ if hasattr (x , '__array_namespace__' ):
262+ namespaces .add (x .__array_namespace__ (api_version = api_version ))
263+ else :
264+ namespaces .add (numpy_namespace )
247265 elif is_cupy_array (x ):
248- _check_api_version (api_version )
249266 if _use_compat :
267+ _check_api_version (api_version )
250268 from .. import cupy as cupy_namespace
251269 namespaces .add (cupy_namespace )
252270 else :
253271 import cupy as cp
254272 namespaces .add (cp )
255273 elif is_torch_array (x ):
256- _check_api_version (api_version )
257274 if _use_compat :
275+ _check_api_version (api_version )
258276 from .. import torch as torch_namespace
259277 namespaces .add (torch_namespace )
260278 else :
261279 import torch
262280 namespaces .add (torch )
263281 elif is_dask_array (x ):
264- _check_api_version (api_version )
265282 if _use_compat :
283+ _check_api_version (api_version )
266284 from ..dask import array as dask_namespace
267285 namespaces .add (dask_namespace )
268286 else :
269- raise TypeError ("_use_compat cannot be False if input array is a dask array!" )
287+ import dask .array as da
288+ namespaces .add (da )
270289 elif is_jax_array (x ):
271- _check_api_version (api_version )
272- # jax.experimental.array_api is already an array namespace. We do
273- # not have a wrapper submodule for it.
274- import jax .experimental .array_api as jnp
290+ if use_compat is True :
291+ _check_api_version (api_version )
292+ raise ValueError ("JAX does not have an array-api-compat wrapper" )
293+ elif use_compat is False :
294+ import jax .numpy as jnp
295+ else :
296+ # jax.experimental.array_api is already an array namespace. We do
297+ # not have a wrapper submodule for it.
298+ import jax .experimental .array_api as jnp
275299 namespaces .add (jnp )
276300 elif hasattr (x , '__array_namespace__' ):
301+ if use_compat is True :
302+ raise ValueError ("The given array does not have an array-api-compat wrapper" )
277303 namespaces .add (x .__array_namespace__ (api_version = api_version ))
278304 else :
279305 # TODO: Support Python scalars?
0 commit comments