33import numpy as np
44import pytest
55
6- from array_api_compat import array_namespace , at , is_dask_array , is_jax_array , is_writeable_array
6+ from array_api_compat import array_namespace , at , is_dask_array , is_jax_array , is_pydata_sparse_array , is_writeable_array
77from ._helpers import import_ , all_libraries
88
99
10+ def assert_array_equal (a , b ):
11+ if is_pydata_sparse_array (a ):
12+ a = a .todense ()
13+ elif is_dask_array (a ):
14+ a = a .compute ()
15+ np .testing .assert_array_equal (a , b )
16+
17+
1018@contextmanager
1119def assert_copy (x , copy : bool | None ):
1220 # dask arrays are writeable, but writing to them will hot-swap the
@@ -21,10 +29,13 @@ def assert_copy(x, copy: bool | None):
2129 x_orig = xp .asarray (x , copy = True )
2230 yield
2331
24- expect_copy = (
25- copy if copy is not None else (not is_writeable_array (x ) or is_dask_array (x ))
26- )
27- np .testing .assert_array_equal ((x == x_orig ).all (), expect_copy )
32+ if is_dask_array (x ):
33+ expect_copy = True
34+ elif copy is None :
35+ expect_copy = not is_writeable_array (x )
36+ else :
37+ expect_copy = copy
38+ assert_array_equal ((x == x_orig ).all (), expect_copy )
2839
2940
3041@pytest .fixture (params = all_libraries + ["np_readonly" ])
@@ -58,15 +69,15 @@ def test_operations(x, copy, op, arg, expect):
5869 with assert_copy (x , copy ):
5970 y = getattr (at (x , slice (1 , None )), op )(arg , copy = copy )
6071 assert isinstance (y , type (x ))
61- np . testing . assert_equal (y , expect )
72+ assert_array_equal (y , expect )
6273
6374
6475@pytest .mark .parametrize ("copy" , [True , False , None ])
6576def test_get (x , copy ):
6677 with assert_copy (x , copy ):
6778 y = at (x , slice (2 )).get (copy = copy )
6879 assert isinstance (y , type (x ))
69- np . testing . assert_array_equal (y , [10 , 20 ])
80+ assert_array_equal (y , [10 , 20 ])
7081 # Let assert_copy test that y is a view or copy
7182 with suppress ((TypeError , ValueError )):
7283 y [0 ] = 40
@@ -97,15 +108,15 @@ def test_get_fancy_indices(x, idx, wrap_index):
97108 with assert_copy (x , True ):
98109 y = at (x , [0 , 1 ]).get ()
99110 assert isinstance (y , type (x ))
100- np . testing . assert_array_equal (y , [10 , 20 ])
111+ assert_array_equal (y , [10 , 20 ])
101112 # Let assert_copy test that y is a view or copy
102113 with suppress ((TypeError , ValueError )):
103114 y [0 ] = 40
104115
105116 with assert_copy (x , True ):
106117 y = at (x , [0 , 1 ]).get (copy = None )
107118 assert isinstance (y , type (x ))
108- np . testing . assert_array_equal (y , [10 , 20 ])
119+ assert_array_equal (y , [10 , 20 ])
109120 # Let assert_copy test that y is a view or copy
110121 with suppress ((TypeError , ValueError )):
111122 y [0 ] = 40
@@ -119,7 +130,7 @@ def test_variant_index_syntax(x, copy):
119130 with assert_copy (x , copy ):
120131 y = at (x )[:2 ].set (40 , copy = copy )
121132 assert isinstance (y , type (x ))
122- np . testing . assert_array_equal (y , [40 , 40 , 30 ])
133+ assert_array_equal (y , [40 , 40 , 30 ])
123134 with pytest .raises (ValueError ):
124135 at (x , 1 )[2 ]
125136 with pytest .raises (ValueError ):
0 commit comments