3232from array_api_extra ._lib ._utils ._typing import Array , Device
3333from array_api_extra .testing import lazy_xp_function
3434
35+ from .conftest import NUMPY_VERSION
36+
3537# some xp backends are untyped
3638# mypy: disable-error-code=no-untyped-def
3739
4850lazy_xp_function (sinc , static_argnames = "xp" )
4951
5052
51- NUMPY_GE2 = int (np .__version__ .split ("." )[0 ]) >= 2
52-
53-
54- @pytest .mark .skip_xp_backend (
55- Backend .SPARSE , reason = "read-only backend without .at support"
56- )
5753class TestApplyWhere :
5854 @staticmethod
5955 def f1 (x : Array , y : Array | int = 10 ) -> Array :
@@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
153149 xp_assert_equal (actual , xp .asarray ([100 , 12 ]))
154150 xp_assert_equal (fill_value , xp .asarray ([100 , 200 ]))
155151
152+ @pytest .mark .skip_xp_backend (
153+ Backend .ARRAY_API_STRICTEST ,
154+ reason = "no boolean indexing -> run everywhere" ,
155+ )
156+ @pytest .mark .skip_xp_backend (
157+ Backend .SPARSE ,
158+ reason = "no indexing by sparse array -> run everywhere" ,
159+ )
156160 def test_dont_run_on_false (self , xp : ModuleType ):
157161 x = xp .asarray ([1.0 , 2.0 , 0.0 ])
158162 y = xp .asarray ([0.0 , 3.0 , 4.0 ])
@@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
192196 y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = x )
193197 assert get_device (y ) == device
194198
199+ @pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no isdtype" )
195200 @pytest .mark .filterwarnings ("ignore::RuntimeWarning" ) # overflows, etc.
196201 @hypothesis .settings (
197202 # The xp and library fixtures are not regenerated between hypothesis iterations
@@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
217222 library : Backend ,
218223 ):
219224 if (
220- library in (Backend .NUMPY , Backend . NUMPY_READONLY )
221- and not NUMPY_GE2
225+ library . like (Backend .NUMPY )
226+ and NUMPY_VERSION < ( 2 , 0 )
222227 and dtype is np .float32
223228 ):
224229 pytest .xfail (reason = "NumPy 1.x dtype promotion for scalars" )
@@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType):
562567 assert y .shape == (1 , 1 , 1 , 3 )
563568
564569
570+ @pytest .mark .filterwarnings ( # array_api_strictest
571+ "ignore:invalid value encountered:RuntimeWarning:array_api_strict"
572+ )
565573@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
566574class TestIsClose :
567575 @pytest .mark .parametrize ("swap" , [False , True ])
@@ -680,13 +688,15 @@ def test_bool_dtype(self, xp: ModuleType):
680688 isclose (xp .asarray (True ), b , atol = 1 ), xp .asarray ([True , True , True ])
681689 )
682690
691+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
683692 def test_none_shape (self , xp : ModuleType ):
684693 a = xp .asarray ([1 , 5 , 0 ])
685694 b = xp .asarray ([1 , 4 , 2 ])
686695 b = b [a < 5 ]
687696 a = a [a < 5 ]
688697 xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
689698
699+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
690700 def test_none_shape_bool (self , xp : ModuleType ):
691701 a = xp .asarray ([True , True , False ])
692702 b = xp .asarray ([True , False , True ])
@@ -819,8 +829,27 @@ def test_empty(self, xp: ModuleType):
819829 a = xp .asarray ([])
820830 xp_assert_equal (nunique (a ), xp .asarray (0 ))
821831
822- def test_device (self , xp : ModuleType , device : Device ):
823- a = xp .asarray (0.0 , device = device )
832+ def test_size1 (self , xp : ModuleType ):
833+ a = xp .asarray ([123 ])
834+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
835+
836+ def test_all_equal (self , xp : ModuleType ):
837+ a = xp .asarray ([123 , 123 , 123 ])
838+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
839+
840+ @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "No equal_nan kwarg in unique" )
841+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "sparse#855" )
842+ def test_nan (self , xp : ModuleType , library : Backend ):
843+ if library .like (Backend .NUMPY ) and NUMPY_VERSION < (1 , 24 ):
844+ pytest .xfail ("NumPy <1.24 has no equal_nan kwarg in unique" )
845+
846+ # Each NaN is counted separately
847+ a = xp .asarray ([xp .nan , 123.0 , xp .nan ])
848+ xp_assert_equal (nunique (a ), xp .asarray (3 ))
849+
850+ @pytest .mark .parametrize ("size" , [0 , 1 , 2 ])
851+ def test_device (self , xp : ModuleType , device : Device , size : int ):
852+ a = xp .asarray ([0.0 ] * size , device = device )
824853 assert get_device (nunique (a )) == device
825854
826855 def test_xp (self , xp : ModuleType ):
@@ -895,6 +924,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
895924
896925
897926@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
927+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "no unique_values" )
898928class TestSetDiff1D :
899929 @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "NaN-shaped arrays" )
900930 @pytest .mark .xfail_xp_backend (
0 commit comments