|
1 | 1 | import keras |
2 | 2 | import pytest |
3 | 3 |
|
4 | | -from bayesflow.utils import find_inference_network, find_distribution, find_summary_network |
| 4 | +from bayesflow.utils import find_inference_network, find_distribution, find_network, find_summary_network |
5 | 5 | from bayesflow.experimental.diffusion_model import find_noise_schedule |
6 | 6 |
|
| 7 | +# --- Tests for find__network.py --- |
| 8 | + |
| 9 | + |
| 10 | +class DummyNetwork: |
| 11 | + def __init__(self, *a, **kw): |
| 12 | + self.args = a |
| 13 | + self.kwargs = kw |
| 14 | + |
| 15 | + |
| 16 | +@pytest.mark.parametrize( |
| 17 | + "name,expected_class_path", |
| 18 | + [ |
| 19 | + ("mlp", "bayesflow.networks.MLP"), |
| 20 | + ], |
| 21 | +) |
| 22 | +def test_find_network_by_name(monkeypatch, name, expected_class_path): |
| 23 | + # patch the expected class in bayesflow.networks |
| 24 | + components = expected_class_path.split(".") |
| 25 | + module_path = ".".join(components[:-1]) |
| 26 | + class_name = components[-1] |
| 27 | + |
| 28 | + dummy_cls = DummyNetwork |
| 29 | + monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls) |
| 30 | + |
| 31 | + net = find_network(name, 1, key="val") |
| 32 | + assert isinstance(net, DummyNetwork) |
| 33 | + assert net.args == (1,) |
| 34 | + assert net.kwargs == {"key": "val"} |
| 35 | + |
| 36 | + |
| 37 | +def test_find_network_by_type(): |
| 38 | + # patch the expected class in bayesflow.networks |
| 39 | + net = find_network(DummyNetwork, 1, key="val") |
| 40 | + assert isinstance(net, DummyNetwork) |
| 41 | + assert net.args == (1,) |
| 42 | + assert net.kwargs == {"key": "val"} |
| 43 | + |
| 44 | + |
| 45 | +def test_find_network_by_keras_layer(): |
| 46 | + layer = keras.layers.Dense(10) |
| 47 | + result = find_network(layer) |
| 48 | + assert result is layer |
| 49 | + |
| 50 | + |
| 51 | +def test_find_network_by_keras_model(): |
| 52 | + model = keras.models.Sequential() |
| 53 | + result = find_network(model) |
| 54 | + assert result is model |
| 55 | + |
| 56 | + |
| 57 | +def test_find_network_unknown_name(): |
| 58 | + with pytest.raises(ValueError): |
| 59 | + find_network("unknown_network_name") |
| 60 | + |
| 61 | + |
| 62 | +def test_find_network_invalid_type(): |
| 63 | + with pytest.raises(TypeError): |
| 64 | + find_network(12345) |
| 65 | + |
7 | 66 |
|
8 | 67 | # --- Tests for find_inference_network.py --- |
9 | 68 |
|
@@ -37,6 +96,14 @@ def test_find_inference_network_by_name(monkeypatch, name, expected_class_path): |
37 | 96 | assert net.kwargs == {"key": "val"} |
38 | 97 |
|
39 | 98 |
|
| 99 | +def test_find_inference_network_by_type(): |
| 100 | + # patch the expected class in bayesflow.networks |
| 101 | + net = find_inference_network(DummyInferenceNetwork, 1, key="val") |
| 102 | + assert isinstance(net, DummyInferenceNetwork) |
| 103 | + assert net.args == (1,) |
| 104 | + assert net.kwargs == {"key": "val"} |
| 105 | + |
| 106 | + |
40 | 107 | def test_find_inference_network_by_keras_layer(): |
41 | 108 | layer = keras.layers.Dense(10) |
42 | 109 | result = find_inference_network(layer) |
@@ -149,6 +216,14 @@ def test_find_summary_network_by_name(monkeypatch, name, expected_class_path): |
149 | 216 | assert net.kwargs == {"flag": True} |
150 | 217 |
|
151 | 218 |
|
| 219 | +def test_find_summary_network_by_type(): |
| 220 | + # patch the expected class in bayesflow.networks |
| 221 | + net = find_summary_network(DummySummaryNetwork, 1, key="val") |
| 222 | + assert isinstance(net, DummySummaryNetwork) |
| 223 | + assert net.args == (1,) |
| 224 | + assert net.kwargs == {"key": "val"} |
| 225 | + |
| 226 | + |
152 | 227 | def test_find_summary_network_by_keras_layer(): |
153 | 228 | layer = keras.layers.Dense(1) |
154 | 229 | out = find_summary_network(layer) |
|
0 commit comments