Skip to content

Commit 92426d6

Browse files
committed
allow dispatch of summary/inference network from type
* add tests for find_network
1 parent afef095 commit 92426d6

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

bayesflow/utils/dispatch/find_inference_network.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def _(name: str, *args, **kwargs):
2929
raise ValueError(f"Unknown inference network: '{unknown_network}'")
3030

3131

32+
@find_inference_network.register
33+
def _(cls: type, *args, **kwargs):
34+
# Instantiate class with the given arguments
35+
network = cls(*args, **kwargs)
36+
return network
37+
38+
3239
@find_inference_network.register
3340
def _(layer: keras.Layer, *args, **kwargs):
3441
return layer

bayesflow/utils/dispatch/find_summary_network.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def _(name: str, *args, **kwargs):
3939
raise ValueError(f"Unknown summary network: '{unknown_network}'")
4040

4141

42+
@find_summary_network.register
43+
def _(cls: type, *args, **kwargs):
44+
# Instantiate class with the given arguments
45+
network = cls(*args, **kwargs)
46+
return network
47+
48+
4249
@find_summary_network.register
4350
def _(layer: keras.Layer, *args, **kwargs):
4451
return layer

tests/test_utils/test_dispatch.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,68 @@
11
import keras
22
import pytest
33

4-
from bayesflow.utils import find_inference_network, find_distribution, find_summary_network
4+
from bayesflow.utils import find_inference_network, find_distribution, find_network, find_summary_network
55
from bayesflow.experimental.diffusion_model import find_noise_schedule
66

7+
# --- Tests for find__network.py ---
8+
9+
10+
class DummyNetwork:
11+
def __init__(self, *a, **kw):
12+
self.args = a
13+
self.kwargs = kw
14+
15+
16+
@pytest.mark.parametrize(
17+
"name,expected_class_path",
18+
[
19+
("mlp", "bayesflow.networks.MLP"),
20+
],
21+
)
22+
def test_find_network_by_name(monkeypatch, name, expected_class_path):
23+
# patch the expected class in bayesflow.networks
24+
components = expected_class_path.split(".")
25+
module_path = ".".join(components[:-1])
26+
class_name = components[-1]
27+
28+
dummy_cls = DummyNetwork
29+
monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls)
30+
31+
net = find_network(name, 1, key="val")
32+
assert isinstance(net, DummyNetwork)
33+
assert net.args == (1,)
34+
assert net.kwargs == {"key": "val"}
35+
36+
37+
def test_find_network_by_type():
38+
# patch the expected class in bayesflow.networks
39+
net = find_network(DummyNetwork, 1, key="val")
40+
assert isinstance(net, DummyNetwork)
41+
assert net.args == (1,)
42+
assert net.kwargs == {"key": "val"}
43+
44+
45+
def test_find_network_by_keras_layer():
46+
layer = keras.layers.Dense(10)
47+
result = find_network(layer)
48+
assert result is layer
49+
50+
51+
def test_find_network_by_keras_model():
52+
model = keras.models.Sequential()
53+
result = find_network(model)
54+
assert result is model
55+
56+
57+
def test_find_network_unknown_name():
58+
with pytest.raises(ValueError):
59+
find_network("unknown_network_name")
60+
61+
62+
def test_find_network_invalid_type():
63+
with pytest.raises(TypeError):
64+
find_network(12345)
65+
766

867
# --- Tests for find_inference_network.py ---
968

@@ -37,6 +96,14 @@ def test_find_inference_network_by_name(monkeypatch, name, expected_class_path):
3796
assert net.kwargs == {"key": "val"}
3897

3998

99+
def test_find_inference_network_by_type():
100+
# patch the expected class in bayesflow.networks
101+
net = find_inference_network(DummyInferenceNetwork, 1, key="val")
102+
assert isinstance(net, DummyInferenceNetwork)
103+
assert net.args == (1,)
104+
assert net.kwargs == {"key": "val"}
105+
106+
40107
def test_find_inference_network_by_keras_layer():
41108
layer = keras.layers.Dense(10)
42109
result = find_inference_network(layer)
@@ -149,6 +216,14 @@ def test_find_summary_network_by_name(monkeypatch, name, expected_class_path):
149216
assert net.kwargs == {"flag": True}
150217

151218

219+
def test_find_summary_network_by_type():
220+
# patch the expected class in bayesflow.networks
221+
net = find_summary_network(DummySummaryNetwork, 1, key="val")
222+
assert isinstance(net, DummySummaryNetwork)
223+
assert net.args == (1,)
224+
assert net.kwargs == {"key": "val"}
225+
226+
152227
def test_find_summary_network_by_keras_layer():
153228
layer = keras.layers.Dense(1)
154229
out = find_summary_network(layer)

0 commit comments

Comments
 (0)