@@ -202,7 +202,6 @@ def is_jax_array(x):
202202
203203 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204204
205-
206205def is_pydata_sparse_array (x ) -> bool :
207206 """
208207 Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
255254 or is_pydata_sparse_array (x ) \
256255 or hasattr (x , '__array_namespace__' )
257256
257+ def _compat_module_name ():
258+ assert __name__ .endswith ('.common._helpers' )
259+ return __name__ .removesuffix ('.common._helpers' )
260+
261+ def is_numpy_namespace (xp ) -> bool :
262+ """
263+ Returns True if `xp` is a NumPy namespace.
264+
265+ This includes both NumPy itself and the version wrapped by array-api-compat.
266+
267+ See Also
268+ --------
269+
270+ array_namespace
271+ is_cupy_namespace
272+ is_torch_namespace
273+ is_ndonnx_namespace
274+ is_dask_namespace
275+ is_jax_namespace
276+ is_pydata_sparse_namespace
277+ is_array_api_strict_namespace
278+ """
279+ return xp .__name__ in {'numpy' , _compat_module_name + '.numpy' }
280+
281+ def is_cupy_namespace (xp ) -> bool :
282+ """
283+ Returns True if `xp` is a CuPy namespace.
284+
285+ This includes both CuPy itself and the version wrapped by array-api-compat.
286+
287+ See Also
288+ --------
289+
290+ array_namespace
291+ is_numpy_namespace
292+ is_torch_namespace
293+ is_ndonnx_namespace
294+ is_dask_namespace
295+ is_jax_namespace
296+ is_pydata_sparse_namespace
297+ is_array_api_strict_namespace
298+ """
299+ return xp .__name__ in {'cupy' , _compat_module_name + '.cupy' }
300+
301+ def is_torch_namespace (xp ) -> bool :
302+ """
303+ Returns True if `xp` is a PyTorch namespace.
304+
305+ This includes both PyTorch itself and the version wrapped by array-api-compat.
306+
307+ See Also
308+ --------
309+
310+ array_namespace
311+ is_numpy_namespace
312+ is_cupy_namespace
313+ is_ndonnx_namespace
314+ is_dask_namespace
315+ is_jax_namespace
316+ is_pydata_sparse_namespace
317+ is_array_api_strict_namespace
318+ """
319+ return xp .__name__ in {'torch' , _compat_module_name + '.torch' }
320+
321+
322+ def is_ndonnx_namespace (xp ):
323+ """
324+ Returns True if `xp` is an NDONNX namespace.
325+
326+ See Also
327+ --------
328+
329+ array_namespace
330+ is_numpy_namespace
331+ is_cupy_namespace
332+ is_torch_namespace
333+ is_dask_namespace
334+ is_jax_namespace
335+ is_pydata_sparse_namespace
336+ is_array_api_strict_namespace
337+ """
338+ return xp .__name__ == 'ndonnx'
339+
340+ def is_dask_namespace (xp ):
341+ """
342+ Returns True if `xp` is a Dask namespace.
343+
344+ This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
345+
346+ See Also
347+ --------
348+
349+ array_namespace
350+ is_numpy_namespace
351+ is_cupy_namespace
352+ is_torch_namespace
353+ is_ndonnx_namespace
354+ is_jax_namespace
355+ is_pydata_sparse_namespace
356+ is_array_api_strict_namespace
357+ """
358+ return xp .__name__ in {'dask.array' , _compat_module_name + '.dask.array' }
359+
360+ def is_jax_namespace (xp ):
361+ """
362+ Returns True if `xp` is a JAX namespace.
363+
364+ This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
365+ older versions of JAX.
366+
367+ See Also
368+ --------
369+
370+ array_namespace
371+ is_numpy_namespace
372+ is_cupy_namespace
373+ is_torch_namespace
374+ is_ndonnx_namespace
375+ is_dask_namespace
376+ is_pydata_sparse_namespace
377+ is_array_api_strict_namespace
378+ """
379+ return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380+
381+ def is_pydata_sparse_namespace (xp ):
382+ """
383+ Returns True if `xp` is a pydata/sparse namespace.
384+
385+ See Also
386+ --------
387+
388+ array_namespace
389+ is_numpy_namespace
390+ is_cupy_namespace
391+ is_torch_namespace
392+ is_ndonnx_namespace
393+ is_dask_namespace
394+ is_jax_namespace
395+ is_array_api_strict_namespace
396+ """
397+ return xp .__name__ == 'sparse'
398+
399+ def is_array_api_strict_namespace (xp ):
400+ """
401+ Returns True if `xp` is an array-api-strict namespace.
402+
403+ See Also
404+ --------
405+
406+ array_namespace
407+ is_numpy_namespace
408+ is_cupy_namespace
409+ is_torch_namespace
410+ is_ndonnx_namespace
411+ is_dask_namespace
412+ is_jax_namespace
413+ is_pydata_sparse_namespace
414+ """
415+ return xp .__name__ == 'array_api_strict'
416+
258417def _check_api_version (api_version ):
259418 if api_version == '2021.12' :
260419 warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
@@ -643,13 +802,21 @@ def size(x):
643802 "device" ,
644803 "get_namespace" ,
645804 "is_array_api_obj" ,
805+ "is_array_api_strict_namespace" ,
646806 "is_cupy_array" ,
807+ "is_cupy_namespace" ,
647808 "is_dask_array" ,
809+ "is_dask_namespace" ,
648810 "is_jax_array" ,
811+ "is_jax_namespace" ,
649812 "is_numpy_array" ,
813+ "is_numpy_namespace" ,
650814 "is_torch_array" ,
815+ "is_torch_namespace" ,
651816 "is_ndonnx_array" ,
817+ "is_ndonnx_namespace" ,
652818 "is_pydata_sparse_array" ,
819+ "is_pydata_sparse_namespace" ,
653820 "size" ,
654821 "to_device" ,
655822]
0 commit comments