1- from array_api_compat import (is_numpy_array , is_cupy_array , is_torch_array , # noqa: F401
2- is_dask_array , is_jax_array , is_pydata_sparse_array )
1+ from array_api_compat import ( # noqa: F401
2+ is_numpy_array , is_cupy_array , is_torch_array ,
3+ is_dask_array , is_jax_array , is_pydata_sparse_array ,
4+ is_numpy_namespace , is_cupy_namespace , is_torch_namespace ,
5+ is_dask_namespace , is_jax_namespace , is_pydata_sparse_namespace ,
6+ )
37
48from array_api_compat import is_array_api_obj , device , to_device
59
1014import array
1115from numpy .testing import assert_allclose
1216
13- is_functions = {
17+ is_array_functions = {
1418 'numpy' : 'is_numpy_array' ,
1519 'cupy' : 'is_cupy_array' ,
1620 'torch' : 'is_torch_array' ,
1923 'sparse' : 'is_pydata_sparse_array' ,
2024}
2125
22- @pytest .mark .parametrize ('library' , is_functions .keys ())
23- @pytest .mark .parametrize ('func' , is_functions .values ())
26+ is_namespace_functions = {
27+ 'numpy' : 'is_numpy_namespace' ,
28+ 'cupy' : 'is_cupy_namespace' ,
29+ 'torch' : 'is_torch_namespace' ,
30+ 'dask.array' : 'is_dask_namespace' ,
31+ 'jax.numpy' : 'is_jax_namespace' ,
32+ 'sparse' : 'is_pydata_sparse_namespace' ,
33+ }
34+
35+
36+ @pytest .mark .parametrize ('library' , is_array_functions .keys ())
37+ @pytest .mark .parametrize ('func' , is_array_functions .values ())
2438def test_is_xp_array (library , func ):
2539 lib = import_ (library )
2640 is_func = globals ()[func ]
2741
2842 x = lib .asarray ([1 , 2 , 3 ])
2943
30- assert is_func (x ) == (func == is_functions [library ])
44+ assert is_func (x ) == (func == is_array_functions [library ])
3145
3246 assert is_array_api_obj (x )
3347
48+
49+ @pytest .mark .parametrize ('library' , is_namespace_functions .keys ())
50+ @pytest .mark .parametrize ('func' , is_namespace_functions .values ())
51+ def test_is_xp_namespace (library , func ):
52+ lib = import_ (library )
53+ is_func = globals ()[func ]
54+
55+ assert is_func (lib ) == (func == is_namespace_functions [library ])
56+
57+
3458@pytest .mark .parametrize ("library" , all_libraries )
3559def test_device (library ):
3660 xp = import_ (library , wrapper = True )
@@ -64,8 +88,8 @@ def test_to_device_host(library):
6488 assert_allclose (x , expected )
6589
6690
67- @pytest .mark .parametrize ("target_library" , is_functions .keys ())
68- @pytest .mark .parametrize ("source_library" , is_functions .keys ())
91+ @pytest .mark .parametrize ("target_library" , is_array_functions .keys ())
92+ @pytest .mark .parametrize ("source_library" , is_array_functions .keys ())
6993def test_asarray_cross_library (source_library , target_library , request ):
7094 if source_library == "dask.array" and target_library == "torch" :
7195 # Allow rest of test to execute instead of immediately xfailing
@@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
81105 pytest .skip (reason = "`sparse` does not allow implicit densification" )
82106 src_lib = import_ (source_library , wrapper = True )
83107 tgt_lib = import_ (target_library , wrapper = True )
84- is_tgt_type = globals ()[is_functions [target_library ]]
108+ is_tgt_type = globals ()[is_array_functions [target_library ]]
85109
86110 a = src_lib .asarray ([1 , 2 , 3 ])
87111 b = tgt_lib .asarray (a )
@@ -96,7 +120,7 @@ def test_asarray_copy(library):
96120 # should be able to delete this.
97121 xp = import_ (library , wrapper = True )
98122 asarray = xp .asarray
99- is_lib_func = globals ()[is_functions [library ]]
123+ is_lib_func = globals ()[is_array_functions [library ]]
100124 all = xp .all if library != 'dask.array' else lambda x : xp .all (x ).compute ()
101125
102126 if library == 'numpy' and xp .__version__ [0 ] < '2' and not hasattr (xp , '_CopyMode' ) :
0 commit comments