1818import inspect
1919import warnings
2020
21- def _is_jax_zero_gradient_array (x ) :
21+ def _is_jax_zero_gradient_array (x : object ) -> bool :
2222 """Return True if `x` is a zero-gradient array.
2323
2424 These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
3232
3333 return isinstance (x , np .ndarray ) and x .dtype == jax .float0
3434
35- def is_numpy_array (x ):
35+
36+ def is_numpy_array (x : object ) -> bool :
3637 """
3738 Return True if `x` is a NumPy array.
3839
@@ -63,7 +64,8 @@ def is_numpy_array(x):
6364 return (isinstance (x , (np .ndarray , np .generic ))
6465 and not _is_jax_zero_gradient_array (x ))
6566
66- def is_cupy_array (x ):
67+
68+ def is_cupy_array (x : object ) -> bool :
6769 """
6870 Return True if `x` is a CuPy array.
6971
@@ -93,7 +95,8 @@ def is_cupy_array(x):
9395 # TODO: Should we reject ndarray subclasses?
9496 return isinstance (x , cp .ndarray )
9597
96- def is_torch_array (x ):
98+
99+ def is_torch_array (x : object ) -> bool :
97100 """
98101 Return True if `x` is a PyTorch tensor.
99102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120123 # TODO: Should we reject ndarray subclasses?
121124 return isinstance (x , torch .Tensor )
122125
123- def is_ndonnx_array (x ):
126+
127+ def is_ndonnx_array (x : object ) -> bool :
124128 """
125129 Return True if `x` is a ndonnx Array.
126130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147151
148152 return isinstance (x , ndx .Array )
149153
150- def is_dask_array (x ):
154+
155+ def is_dask_array (x : object ) -> bool :
151156 """
152157 Return True if `x` is a dask.array Array.
153158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174179
175180 return isinstance (x , dask .array .Array )
176181
177- def is_jax_array (x ):
182+
183+ def is_jax_array (x : object ) -> bool :
178184 """
179185 Return True if `x` is a JAX array.
180186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202208
203209 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204210
211+
205212def is_pydata_sparse_array (x ) -> bool :
206213 """
207214 Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231238 # TODO: Account for other backends.
232239 return isinstance (x , sparse .SparseArray )
233240
234- def is_array_api_obj (x ):
241+
242+ def is_array_api_obj (x : object ) -> bool :
235243 """
236244 Return True if `x` is an array API compatible array object.
237245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254262 or is_pydata_sparse_array (x ) \
255263 or hasattr (x , '__array_namespace__' )
256264
257- def _compat_module_name ():
265+
266+ def _compat_module_name () -> str :
258267 assert __name__ .endswith ('.common._helpers' )
259268 return __name__ .removesuffix ('.common._helpers' )
260269
270+
261271def is_numpy_namespace (xp ) -> bool :
262272 """
263273 Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278288 """
279289 return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280290
291+
281292def is_cupy_namespace (xp ) -> bool :
282293 """
283294 Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298309 """
299310 return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300311
312+
301313def is_torch_namespace (xp ) -> bool :
302314 """
303315 Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319331 return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320332
321333
322- def is_ndonnx_namespace (xp ):
334+ def is_ndonnx_namespace (xp ) -> bool :
323335 """
324336 Returns True if `xp` is an NDONNX namespace.
325337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337349 """
338350 return xp .__name__ == 'ndonnx'
339351
340- def is_dask_namespace (xp ):
352+
353+ def is_dask_namespace (xp ) -> bool :
341354 """
342355 Returns True if `xp` is a Dask namespace.
343356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357370 """
358371 return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359372
360- def is_jax_namespace (xp ):
373+
374+ def is_jax_namespace (xp ) -> bool :
361375 """
362376 Returns True if `xp` is a JAX namespace.
363377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378392 """
379393 return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380394
381- def is_pydata_sparse_namespace (xp ):
395+
396+ def is_pydata_sparse_namespace (xp ) -> bool :
382397 """
383398 Returns True if `xp` is a pydata/sparse namespace.
384399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396411 """
397412 return xp .__name__ == 'sparse'
398413
399- def is_array_api_strict_namespace (xp ):
414+
415+ def is_array_api_strict_namespace (xp ) -> bool :
400416 """
401417 Returns True if `xp` is an array-api-strict namespace.
402418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414430 """
415431 return xp .__name__ == 'array_api_strict'
416432
417- def _check_api_version (api_version ):
433+
434+ def _check_api_version (api_version : str ) -> None :
418435 if api_version in ['2021.12' , '2022.12' ]:
419436 warnings .warn (f"The { api_version } version of the array API specification was requested but the returned namespace is actually version 2023.12" )
420437 elif api_version is not None and api_version not in ['2021.12' , '2022.12' ,
421438 '2023.12' ]:
422439 raise ValueError ("Only the 2023.12 version of the array API specification is currently supported" )
423440
441+
424442def array_namespace (* xs , api_version = None , use_compat = None ):
425443 """
426444 Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +826,10 @@ def size(x: Array) -> int | None:
808826 return None if math .isnan (out ) else out
809827
810828
811- def is_writeable_array (x ) -> bool :
829+ def is_writeable_array (x : object ) -> bool :
812830 """
813831 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
832+ Return False if `x` is not an array API compatible object.
814833
815834 Warning
816835 -------
@@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool:
821840 return x .flags .writeable
822841 if is_jax_array (x ) or is_pydata_sparse_array (x ):
823842 return False
824- return True
843+ return is_array_api_obj ( x )
825844
826845
827- def is_lazy_array (x ) -> bool :
846+ def is_lazy_array (x : object ) -> bool :
828847 """Return True if x is potentially a future or it may be otherwise impossible or
829848 expensive to eagerly read its contents, regardless of their size, e.g. by
830849 calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool:
857876 if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
858877 return True
859878
879+ if not is_array_api_obj (x ):
880+ return False
881+
860882 # Unknown Array API compatible object. Note that this test may have dire consequences
861883 # in terms of performance, e.g. for a lazy object that eagerly computes the graph
862884 # on __bool__ (dask is one such example, which however is special-cased above).
0 commit comments