File tree Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ def is_cupy_array(x):
9191 import cupy as cp
9292
9393 # TODO: Should we reject ndarray subclasses?
94- return isinstance (x , ( cp .ndarray , cp . generic ) )
94+ return isinstance (x , cp .ndarray )
9595
9696def is_torch_array (x ):
9797 """
Original file line number Diff line number Diff line change @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
5555 assert is_func (lib ) == (func == is_namespace_functions [library ])
5656
5757
58+ @pytest .mark .parametrize ('library' , all_libraries )
59+ def test_xp_is_array_generics (library ):
60+ """
61+ Test that scalar selection on a xp.ndarray always returns
62+ an object that matches with exactly one among the is_*_array
63+ function of the same library and is_numpy_array.
64+ """
65+ lib = import_ (library )
66+ x = lib .asarray ([1 , 2 , 3 ])
67+ x0 = x [0 ]
68+
69+ matches = []
70+ for library2 , func in is_array_functions .items ():
71+ is_func = globals ()[func ]
72+ if is_func (x0 ):
73+ matches .append (library2 )
74+ assert matches in ([library ], ["numpy" ])
75+
76+
5877@pytest .mark .parametrize ("library" , all_libraries )
5978def test_device (library ):
6079 xp = import_ (library , wrapper = True )
You can’t perform that action at this time.
0 commit comments