Skip to content

Commit 1fe8c94

Browse files
LarsKue and Copilot authored
Add a custom Sequential network to avoid issues with building and serialization in keras (#493)
* add custom sequential to fix #491
* revert using Sequential in classifier_two_sample_test.py
* Add docstring to custom Sequential
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix copilot docstring
* remove mlp override methods
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent cd2c212 commit 1fe8c94

File tree

7 files changed

+123
-36
lines changed

7 files changed

+123
-36
lines changed

bayesflow/networks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .point_inference_network import PointInferenceNetwork
1313
from .mlp import MLP
1414
from .fusion_network import FusionNetwork
15+
from .sequential import Sequential
1516
from .summary_network import SummaryNetwork
1617
from .time_series_network import TimeSeriesNetwork
1718
from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer

bayesflow/networks/mlp/mlp.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33

44
import keras
55

6-
from bayesflow.utils import sequential_kwargs
7-
from bayesflow.utils.serialization import deserialize, serializable, serialize
6+
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.serialization import serializable, serialize
88

9+
from ..sequential import Sequential
910
from ..residual import Residual
1011

1112

1213
@serializable("bayesflow.networks")
13-
class MLP(keras.Sequential):
14+
class MLP(Sequential):
1415
"""
1516
Implements a simple configurable MLP with optional residual connections and dropout.
1617
@@ -67,40 +68,19 @@ def __init__(
6768
self.norm = norm
6869
self.spectral_normalization = spectral_normalization
6970

70-
layers = []
71+
blocks = []
7172

7273
for width in widths:
73-
layer = self._make_layer(
74+
block = self._make_block(
7475
width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization
7576
)
76-
layers.append(layer)
77-
78-
super().__init__(layers, **sequential_kwargs(kwargs))
79-
80-
def build(self, input_shape=None):
81-
if self.built:
82-
# building when the network is already built can cause issues with serialization
83-
# see https://github.com/keras-team/keras/issues/21147
84-
return
85-
86-
# we only care about the last dimension, and using ... signifies to keras.Sequential
87-
# that any number of batch dimensions is valid (which is what we want for all sublayers)
88-
# we also have to avoid calling super().build() because this causes
89-
# shape errors when building on non-sets but doing inference on sets
90-
# this is a work-around for https://github.com/keras-team/keras/issues/21158
91-
input_shape = (..., input_shape[-1])
92-
93-
for layer in self._layers:
94-
layer.build(input_shape)
95-
input_shape = layer.compute_output_shape(input_shape)
77+
blocks.append(block)
9678

97-
@classmethod
98-
def from_config(cls, config, custom_objects=None):
99-
return cls(**deserialize(config, custom_objects=custom_objects))
79+
super().__init__(*blocks, **kwargs)
10080

10181
def get_config(self):
10282
base_config = super().get_config()
103-
base_config = sequential_kwargs(base_config)
83+
base_config = layer_kwargs(base_config)
10484

10585
config = {
10686
"widths": self.widths,
@@ -115,7 +95,7 @@ def get_config(self):
11595
return base_config | serialize(config)
11696

11797
@staticmethod
118-
def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
98+
def _make_block(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
11999
layers = []
120100

121101
dense = keras.layers.Dense(width, kernel_initializer=kernel_initializer)
@@ -148,4 +128,4 @@ def _make_layer(width, activation, kernel_initializer, residual, dropout, norm,
148128
if residual:
149129
return Residual(*layers)
150130

151-
return keras.Sequential(layers)
131+
return Sequential(layers)

bayesflow/networks/residual/residual.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from bayesflow.utils import sequential_kwargs
77
from bayesflow.utils.serialization import deserialize, serializable, serialize
88

9+
from ..sequential import Sequential
10+
911

1012
@serializable("bayesflow.networks")
11-
class Residual(keras.Sequential):
13+
class Residual(Sequential):
1214
def __init__(self, *layers: keras.Layer, **kwargs):
1315
if len(layers) == 1 and isinstance(layers[0], Sequence):
1416
layers = layers[0]
bayesflow/networks/sequential/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .sequential import Sequential
bayesflow/networks/sequential/sequential.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from collections.abc import Sequence
2+
import keras
3+
4+
from bayesflow.utils import layer_kwargs
5+
from bayesflow.utils.serialization import deserialize, serializable, serialize
6+
7+
8+
@serializable("bayesflow.networks")
class Sequential(keras.Layer):
    """
    A custom sequential container that applies a sequence of Keras layers in order.

    This class extends `keras.Layer` and provides functionality for building,
    calling, and serializing a sequence of layers. Unlike `keras.Sequential`,
    this implementation does not eagerly check input shapes, meaning it is
    compatible with both single inputs and sets.

    Parameters
    ----------
    layers : keras.Layer | Sequence[keras.Layer]
        The layers to apply, in order.
        Can be passed by unpacking or as a single sequence.
    **kwargs :
        Additional keyword arguments. They are filtered through `layer_kwargs`
        before being passed to the base `keras.Layer` class.
    """

    def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
        super().__init__(**layer_kwargs(kwargs))

        # allow both Sequential(a, b, c) and Sequential([a, b, c])
        if len(layers) == 1 and isinstance(layers[0], Sequence):
            layers = layers[0]

        self._layers = layers

    def build(self, input_shape):
        """
        Build every sub-layer in order, propagating the output shape of each
        layer as the input shape of the next one.

        Parameters
        ----------
        input_shape :
            Shape of the input to the first layer.
        """
        if self.built:
            # building when the network is already built can cause issues with serialization
            # see https://github.com/keras-team/keras/issues/21147
            return

        for layer in self._layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)

    def call(self, inputs, training=None, mask=None):
        """
        Apply all sub-layers to `inputs` in order and return the final output.

        `training` and `mask` are only forwarded to sub-layers whose `call`
        accepts them (see `_make_kwargs_for_layer`).
        """
        x = inputs
        for layer in self._layers:
            kwargs = self._make_kwargs_for_layer(layer, training, mask)
            x = layer(x, **kwargs)
        return x

    def compute_output_shape(self, input_shape):
        """Return the output shape obtained by chaining the sub-layers' output shapes."""
        for layer in self._layers:
            input_shape = layer.compute_output_shape(input_shape)

        return input_shape

    def get_config(self):
        """Return the filtered base layer config plus the serialized sub-layers under ``"layers"``."""
        base_config = super().get_config()
        base_config = layer_kwargs(base_config)

        config = {
            "layers": [serialize(layer) for layer in self._layers],
        }

        return base_config | config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """
        Re-create an instance from a config produced by `get_config`.

        `get_config` stores the sub-layers under the ``"layers"`` key, but
        `__init__` accepts them only positionally (``*layers``). Passing the
        whole config as keywords would route ``layers=`` into ``**kwargs``
        instead of restoring the sub-layers, so the entry is popped and
        passed positionally. Configs without a ``"layers"`` entry are passed
        as keywords only, exactly as before.
        """
        config = dict(deserialize(config, custom_objects=custom_objects))

        layers = config.pop("layers", None)
        if layers is None:
            return cls(**config)

        return cls(*layers, **config)

    @property
    def layers(self):
        """The sub-layers in application order, as stored by `__init__`."""
        return self._layers

    @staticmethod
    def _make_kwargs_for_layer(layer, training, mask):
        """Collect the `mask` / `training` keyword arguments that `layer`'s call accepts."""
        kwargs = {}
        if layer._call_has_mask_arg:
            # the mask is forwarded even when it is None
            kwargs["mask"] = mask
        if layer._call_has_training_arg and training is not None:
            # leave `training` unset when it is None
            kwargs["training"] = training
        return kwargs

tests/test_networks/test_mlp/conftest.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,24 @@
33
from bayesflow.networks import MLP
44

55

6+
@pytest.fixture(params=[None, 0.0, 0.1])
7+
def dropout(request):
8+
return request.param
9+
10+
11+
@pytest.fixture(params=[None, "batch"])
12+
def norm(request):
13+
return request.param
14+
15+
16+
@pytest.fixture(params=[False, True])
17+
def residual(request):
18+
return request.param
19+
20+
621
@pytest.fixture()
7-
def mlp():
8-
return MLP([64, 64])
22+
def mlp(dropout, norm, residual):
23+
return MLP([64, 64], dropout=dropout, norm=norm, residual=residual)
924

1025

1126
@pytest.fixture()

tests/test_networks/test_mlp/test_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from bayesflow.utils.serialization import deserialize, serialize
44

5-
from ...utils import assert_models_equal
5+
from ...utils import assert_layers_equal
66

77

88
def test_serialize_deserialize(mlp, build_shapes):
@@ -21,4 +21,4 @@ def test_save_and_load(tmp_path, mlp, build_shapes):
2121
keras.saving.save_model(mlp, tmp_path / "model.keras")
2222
loaded = keras.saving.load_model(tmp_path / "model.keras")
2323

24-
assert_models_equal(mlp, loaded)
24+
assert_layers_equal(mlp, loaded)

0 commit comments

Comments
 (0)