 
 import keras
 
-from bayesflow.utils import sequential_kwargs
-from bayesflow.utils.serialization import deserialize, serializable, serialize
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable, serialize
 
+from ..sequential import Sequential
 from ..residual import Residual
 
 
 @serializable("bayesflow.networks")
-class MLP(keras.Sequential):
+class MLP(Sequential):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.
 
@@ -67,40 +68,19 @@ def __init__(
         self.norm = norm
         self.spectral_normalization = spectral_normalization
 
-        layers = []
+        blocks = []
 
         for width in widths:
-            layer = self._make_layer(
+            block = self._make_block(
                 width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization
             )
-            layers.append(layer)
-
-        super().__init__(layers, **sequential_kwargs(kwargs))
-
-    def build(self, input_shape=None):
-        if self.built:
-            # building when the network is already built can cause issues with serialization
-            # see https://github.com/keras-team/keras/issues/21147
-            return
-
-        # we only care about the last dimension, and using ... signifies to keras.Sequential
-        # that any number of batch dimensions is valid (which is what we want for all sublayers)
-        # we also have to avoid calling super().build() because this causes
-        # shape errors when building on non-sets but doing inference on sets
-        # this is a work-around for https://github.com/keras-team/keras/issues/21158
-        input_shape = (..., input_shape[-1])
-
-        for layer in self._layers:
-            layer.build(input_shape)
-            input_shape = layer.compute_output_shape(input_shape)
+            blocks.append(block)
 
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
+        super().__init__(*blocks, **kwargs)
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = sequential_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "widths": self.widths,
@@ -115,7 +95,7 @@ def get_config(self):
         return base_config | serialize(config)
 
     @staticmethod
-    def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
+    def _make_block(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
         layers = []
 
         dense = keras.layers.Dense(width, kernel_initializer=kernel_initializer)
@@ -148,4 +128,4 @@ def _make_layer(width, activation, kernel_initializer, residual, dropout, norm,
         if residual:
             return Residual(*layers)
 
-        return keras.Sequential(layers)
+        return Sequential(layers)
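
For context, a minimal usage sketch of the refactored class follows; it is not part of the commit. It assumes MLP is importable from bayesflow.networks (consistent with the @serializable namespace in the diff) and that widths, residual, and dropout are accepted as constructor keywords, as suggested by the __init__ body above; the full signature and defaults are not shown in this diff. residual=False is passed only to keep the sketch independent of how Residual handles a width change.

import keras

from bayesflow.networks import MLP  # import path assumed from the @serializable namespace

# Two hidden blocks of width 64; each entry in `widths` becomes one block built by _make_block().
net = MLP(widths=[64, 64], residual=False, dropout=0.05)

x = keras.ops.ones((16, 8))  # batch of 16 samples with 8 features
y = net(x)                   # blocks are applied in order by the Sequential base class
print(keras.ops.shape(y))    # last dimension equals the final entry of `widths`, i.e. 64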