55# import jax
66import numpy as np
77import pytest
8- # import torch
8+ import torch
99import paddle
1010
1111import array_api_compat
@@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
7373"""
7474 subprocess .run ([sys .executable , "-c" , code ], check = True )
7575
76- # def test_jax_zero_gradient():
77- # jx = jax.numpy.arange(4)
78- # jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
79- # assert (array_api_compat.get_namespace(jax_zero) is
80- # array_api_compat.get_namespace(jx))
76+ def test_jax_zero_gradient ():
77+ jx = jax .numpy .arange (4 )
78+ jax_zero = jax .vmap (jax .grad (jax .numpy .float32 , allow_int = True ))(jx )
79+ assert (array_api_compat .get_namespace (jax_zero ) is
80+ array_api_compat .get_namespace (jx ))
8181
8282def test_array_namespace_errors ():
8383 pytest .raises (TypeError , lambda : array_namespace ([1 ]))
@@ -87,53 +87,32 @@ def test_array_namespace_errors():
8787 pytest .raises (TypeError , lambda : array_namespace ((x , x )))
8888 pytest .raises (TypeError , lambda : array_namespace (x , (x , x )))
8989
90- # def test_array_namespace_errors_torch():
91- # y = torch.asarray([1, 2])
92- # x = np.asarray([1, 2])
93- # pytest.raises(TypeError, lambda: array_namespace(x, y))
90+ def test_array_namespace_errors_torch ():
91+ y = torch .asarray ([1 , 2 ])
92+ x = np .asarray ([1 , 2 ])
93+ pytest .raises (TypeError , lambda : array_namespace (x , y ))
9494
9595
9696def test_array_namespace_errors_paddle ():
9797 y = paddle .to_tensor ([1 , 2 ])
9898 x = np .asarray ([1 , 2 ])
9999 pytest .raises (TypeError , lambda : array_namespace (x , y ))
100100
101-
102- # def test_api_version():
103- # x = torch.asarray([1, 2])
104- # torch_ = import_("torch", wrapper=True)
105- # assert array_namespace(x, api_version="2023.12") == torch_
106- # assert array_namespace(x, api_version=None) == torch_
107- # assert array_namespace(x) == torch_
108- # # Should issue a warning
109- # with warnings.catch_warnings(record=True) as w:
110- # assert array_namespace(x, api_version="2021.12") == torch_
111- # assert len(w) == 1
112- # assert "2021.12" in str(w[0].message)
113-
114- # # Should issue a warning
115- # with warnings.catch_warnings(record=True) as w:
116- # assert array_namespace(x, api_version="2022.12") == torch_
117- # assert len(w) == 1
118- # assert "2022.12" in str(w[0].message)
119-
120- # pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
121-
122101def test_api_version ():
123- x = paddle .asarray ([1 , 2 ])
124- paddle_ = import_ ("paddle " , wrapper = True )
125- assert array_namespace (x , api_version = "2023.12" ) == paddle_
126- assert array_namespace (x , api_version = None ) == paddle_
127- assert array_namespace (x ) == paddle_
102+ x = torch .asarray ([1 , 2 ])
103+ torch_ = import_ ("torch " , wrapper = True )
104+ assert array_namespace (x , api_version = "2023.12" ) == torch_
105+ assert array_namespace (x , api_version = None ) == torch_
106+ assert array_namespace (x ) == torch_
128107 # Should issue a warning
129108 with warnings .catch_warnings (record = True ) as w :
130- assert array_namespace (x , api_version = "2021.12" ) == paddle_
109+ assert array_namespace (x , api_version = "2021.12" ) == torch_
131110 assert len (w ) == 1
132111 assert "2021.12" in str (w [0 ].message )
133112
134113 # Should issue a warning
135114 with warnings .catch_warnings (record = True ) as w :
136- assert array_namespace (x , api_version = "2022.12" ) == paddle_
115+ assert array_namespace (x , api_version = "2022.12" ) == torch_
137116 assert len (w ) == 1
138117 assert "2022.12" in str (w [0 ].message )
139118
0 commit comments