Skip to content

Commit 931f89f

Browse files
jessegrabowskiricardoV94
authored andcommitted
importorskip tests on optional dependencies
1 parent e552245 commit 931f89f

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tests/sampling/test_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import pytensor.tensor as pt
2828
import pytest
2929

30-
from numpyro.infer import MCMC
3130
from pytensor.compile import SharedVariable
3231
from pytensor.graph import graph_inputs
3332

@@ -45,6 +44,8 @@
4544
sample_numpyro_nuts,
4645
)
4746

47+
MCMC = pytest.importorskip("numpyro.infer.MCMC")
48+
4849

4950
def test_jax_PosDefMatrix():
5051
x = pt.tensor(name="x", shape=(2, 2), dtype="float32")

tests/sampling/test_mcmc_external.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
7676

7777

7878
def test_step_args():
79+
pytest.importorskip("numpyro")
80+
7981
with Model() as model:
8082
a = Normal("a")
8183
idata = sample(
@@ -91,6 +93,9 @@ def test_step_args():
9193

9294
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
9395
def test_sample_var_names(nuts_sampler):
96+
if nuts_sampler != "pymc":
97+
pytest.importorskip(nuts_sampler)
98+
9499
seed = 1234
95100
kwargs = {
96101
"chains": 1,

0 commit comments

Comments
 (0)