Skip to content

Commit b75c5f5

Browse files
add unit tests for single level graphical model
1 parent 9284333 commit b75c5f5

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

tests/test_simulators/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,10 @@ def fixed_mu():
247247
)
248248
def simulator(request):
249249
return request.getfixturevalue(request.param)
250+
251+
252+
@pytest.fixture()
253+
def single_level_simulator():
254+
from bayesflow.experimental.graphical_simulator.example_simulators import single_level
255+
256+
return single_level()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
3+
import bayesflow as bf
4+
5+
6+
def test_single_level_simulator(single_level_simulator):
7+
assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
8+
assert isinstance(single_level_simulator.sample(5), dict)
9+
10+
samples = single_level_simulator.sample((12,))
11+
expected_keys = ["N", "beta", "sigma", "x", "y"]
12+
13+
assert set(samples.keys()) == set(expected_keys)
14+
assert 5 <= samples["N"] < 15
15+
assert np.shape(samples["beta"]) == (12, 2) # num_samples, beta_dim
16+
assert np.shape(samples["sigma"]) == (12, 1) # num_samples, sigma_dim
17+
assert np.shape(samples["x"]) == (12, samples["N"])
18+
assert np.shape(samples["y"]) == (12, samples["N"])

0 commit comments

Comments
 (0)