22import sys
33import warnings
44
5- import jax
65import numpy as np
76import pytest
8- import torch
97
108import array_api_compat
119from array_api_compat import array_namespace
@@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
7674 subprocess .run ([sys .executable , "-c" , code ], check = True )
7775
7876def test_jax_zero_gradient ():
77+ jax = import_ ("jax" )
7978 jx = jax .numpy .arange (4 )
8079 jax_zero = jax .vmap (jax .grad (jax .numpy .float32 , allow_int = True ))(jx )
8180 assert array_namespace (jax_zero ) is array_namespace (jx )
@@ -89,11 +88,13 @@ def test_array_namespace_errors():
8988 pytest .raises (TypeError , lambda : array_namespace (x , (x , x )))
9089
9190def test_array_namespace_errors_torch ():
91+ torch = import_ ("torch" )
9292 y = torch .asarray ([1 , 2 ])
9393 x = np .asarray ([1 , 2 ])
9494 pytest .raises (TypeError , lambda : array_namespace (x , y ))
9595
9696def test_api_version_torch ():
97+ torch = import_ ("torch" )
9798 x = torch .asarray ([1 , 2 ])
9899 torch_ = import_ ("torch" , wrapper = True )
99100 assert array_namespace (x , api_version = "2023.12" ) == torch_
@@ -118,6 +119,7 @@ def test_get_namespace():
118119 assert array_api_compat .get_namespace is array_namespace
119120
120121def test_python_scalars ():
122+ torch = import_ ("torch" )
121123 a = torch .asarray ([1 , 2 ])
122124 xp = import_ ("torch" , wrapper = True )
123125
0 commit comments