1- import contextlib
21import math
32import warnings
43from types import ModuleType
2423from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
2524from array_api_extra ._lib ._utils ._compat import device as get_device
2625from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
27- from array_api_extra ._lib ._utils ._typing import Array , Device
26+ from array_api_extra ._lib ._utils ._typing import Device
2827from array_api_extra .testing import lazy_xp_function
2928
3029# some xp backends are untyped
@@ -291,22 +290,12 @@ def test_xp(self, xp: ModuleType):
291290
292291class TestExpandDims :
293292 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
294- @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "tuple index out of range" )
295- @pytest .mark .xfail_xp_backend (Backend .TORCH , reason = "tuple index out of range" )
296- def test_functionality (self , xp : ModuleType ):
297- def _squeeze_all (b : Array ) -> Array :
298- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
299- for axis in range (b .ndim ):
300- with contextlib .suppress (ValueError ):
301- b = xp .squeeze (b , axis = axis )
302- return b
303-
304- s = (2 , 3 , 4 , 5 )
305- a = xp .empty (s )
293+ def test_single_axis (self , xp : ModuleType ):
294+ """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
295+ a = xp .empty ((2 , 3 , 4 , 5 ))
306296 for axis in range (- 5 , 4 ):
307297 b = expand_dims (a , axis = axis )
308- assert b .shape [axis ] == 1
309- assert _squeeze_all (b ).shape == s
298+ xp_assert_equal (b , xp .expand_dims (a , axis = axis ))
310299
311300 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
312301 def test_axis_tuple (self , xp : ModuleType ):
@@ -317,8 +306,7 @@ def test_axis_tuple(self, xp: ModuleType):
317306 assert expand_dims (a , axis = (0 , - 3 , - 5 )).shape == (1 , 1 , 3 , 1 , 3 , 3 )
318307
319308 def test_axis_out_of_range (self , xp : ModuleType ):
320- s = (2 , 3 , 4 , 5 )
321- a = xp .empty (s )
309+ a = xp .empty ((2 , 3 , 4 , 5 ))
322310 with pytest .raises (IndexError , match = "out of bounds" ):
323311 _ = expand_dims (a , axis = - 6 )
324312 with pytest .raises (IndexError , match = "out of bounds" ):
0 commit comments