Skip to content

Commit d9e9782

Browse files
arrjon and vpratz authored
Allow separate inputs to subnets for continuous models (#521)
Introduces easy access to the different inputs x, t and conditions, to allow for specialized processing of each input, which can be beneficial for more advanced use cases. --------- Co-authored-by: Valentin Pratz <git@valentinpratz.de>
1 parent 47d2766 commit d9e9782

File tree

5 files changed

+280
-57
lines changed

5 files changed

+280
-57
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from bayesflow.types import Tensor
7-
from bayesflow.utils import find_network, layer_kwargs, weighted_mean
7+
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
88
from bayesflow.utils.serialization import deserialize, serializable, serialize
99

1010
from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
6767
Final number of discretization steps
6868
subnet_kwargs: dict[str, any], optional
6969
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
70+
concatenate_subnet_input: bool, optional
71+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
72+
into a single vector or passed as separate arguments. If set to False, the subnet
73+
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
74+
and optional 'conditions'. Default is True.
7075
**kwargs : dict, optional, default: {}
7176
Additional keyword arguments
7277
"""
@@ -77,6 +82,7 @@ def __init__(
7782
subnet_kwargs = subnet_kwargs or {}
7883
if subnet == "mlp":
7984
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
85+
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
8086

8187
self.subnet = find_network(subnet, **subnet_kwargs)
8288
self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
119125
"eps": self.eps,
120126
"s0": self.s0,
121127
"s1": self.s1,
128+
"concatenate_subnet_input": self._concatenate_subnet_input,
122129
# we do not need to store subnet_kwargs
123130
}
124131

@@ -161,18 +168,23 @@ def build(self, xz_shape, conditions_shape=None):
161168

162169
input_shape = list(xz_shape)
163170

164-
# time vector
165-
input_shape[-1] += 1
171+
if self._concatenate_subnet_input:
172+
# construct time vector
173+
input_shape[-1] += 1
174+
if conditions_shape is not None:
175+
input_shape[-1] += conditions_shape[-1]
176+
input_shape = tuple(input_shape)
166177

167-
if conditions_shape is not None:
168-
input_shape[-1] += conditions_shape[-1]
169-
170-
input_shape = tuple(input_shape)
171-
172-
self.subnet.build(input_shape)
173-
174-
input_shape = self.subnet.compute_output_shape(input_shape)
175-
self.output_projector.build(input_shape)
178+
self.subnet.build(input_shape)
179+
out_shape = self.subnet.compute_output_shape(input_shape)
180+
else:
181+
# Multiple separate inputs
182+
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
183+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
184+
out_shape = self.subnet.compute_output_shape(
185+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
186+
)
187+
self.output_projector.build(out_shape)
176188

177189
# Choose coefficient according to [2] Section 3.3
178190
self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
@@ -256,6 +268,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
256268
x = self.consistency_function(x_n, t, conditions=conditions, training=training)
257269
return x
258270

271+
def _apply_subnet(
272+
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
273+
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
274+
"""
275+
Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
276+
the time `t`, and optional conditions or by returning them separately.
277+
278+
Parameters
279+
----------
280+
x : Tensor
281+
The parameter tensor, typically of shape (..., D), but can vary.
282+
t : Tensor
283+
The time tensor, typically of shape (..., 1).
284+
conditions : Tensor, optional
285+
The optional conditioning tensor (e.g. parameters).
286+
training : bool, optional
287+
The training mode flag, which can be used to control behavior during training.
288+
289+
Returns
290+
-------
291+
Tensor
292+
The output tensor from the subnet.
293+
"""
294+
if self._concatenate_subnet_input:
295+
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
296+
return self.subnet(xtc, training=training)
297+
else:
298+
return self.subnet(x=x, t=t, conditions=conditions, training=training)
299+
259300
def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
260301
"""Compute consistency function.
261302
@@ -271,12 +312,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
271312
Whether internal layers (e.g., dropout) should behave in train or inference mode.
272313
"""
273314

274-
if conditions is not None:
275-
xtc = ops.concatenate([x, t, conditions], axis=-1)
276-
else:
277-
xtc = ops.concatenate([x, t], axis=-1)
278-
279-
f = self.output_projector(self.subnet(xtc, training=training))
315+
subnet_out = self._apply_subnet(x, t, conditions, training=training)
316+
f = self.output_projector(subnet_out)
280317

281318
# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
282319
# Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +353,8 @@ def compute_metrics(
316353

317354
log_p = ops.log(p)
318355
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
319-
t1 = ops.take(discretized_time, times)[..., None]
320-
t2 = ops.take(discretized_time, times + 1)[..., None]
356+
t1 = expand_right_as(ops.take(discretized_time, times), x)
357+
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
321358

322359
# generate noise vector
323360
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def __init__(
8585
Additional keyword arguments passed to the noise schedule constructor. Default is None.
8686
integrate_kwargs : dict[str, any], optional
8787
Configuration dictionary for integration during training or inference. Default is None.
88+
concatenate_subnet_input: bool, optional
89+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
90+
into a single vector or passed as separate arguments. If set to False, the subnet
91+
must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
92+
and optional 'conditions'. Default is True.
93+
8894
**kwargs
8995
Additional keyword arguments passed to the base class and internal components.
9096
"""
@@ -116,6 +122,7 @@ def __init__(
116122
if subnet == "mlp":
117123
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
118124
self.subnet = find_network(subnet, **subnet_kwargs)
125+
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
119126

120127
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
121128

@@ -128,15 +135,23 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
128135
self.output_projector.units = xz_shape[-1]
129136
input_shape = list(xz_shape)
130137

131-
# construct time vector
132-
input_shape[-1] += 1
133-
if conditions_shape is not None:
134-
input_shape[-1] += conditions_shape[-1]
138+
if self._concatenate_subnet_input:
139+
# construct time vector
140+
input_shape[-1] += 1
141+
if conditions_shape is not None:
142+
input_shape[-1] += conditions_shape[-1]
143+
input_shape = tuple(input_shape)
135144

136-
input_shape = tuple(input_shape)
145+
self.subnet.build(input_shape)
146+
out_shape = self.subnet.compute_output_shape(input_shape)
147+
else:
148+
# Multiple separate inputs
149+
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
150+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
151+
out_shape = self.subnet.compute_output_shape(
152+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
153+
)
137154

138-
self.subnet.build(input_shape)
139-
out_shape = self.subnet.compute_output_shape(input_shape)
140155
self.output_projector.build(out_shape)
141156

142157
def get_config(self):
@@ -149,6 +164,8 @@ def get_config(self):
149164
"prediction_type": self._prediction_type,
150165
"loss_type": self._loss_type,
151166
"integrate_kwargs": self.integrate_kwargs,
167+
"concatenate_subnet_input": self._concatenate_subnet_input,
168+
# we do not need to store subnet_kwargs
152169
}
153170
return base_config | serialize(config)
154171

@@ -197,6 +214,35 @@ def convert_prediction_to_x(
197214
return (z + sigma_t**2 * pred) / alpha_t
198215
raise ValueError(f"Unknown prediction type {self._prediction_type}.")
199216

217+
def _apply_subnet(
218+
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
219+
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
220+
"""
221+
Prepares and passes the input to the subnet either by concatenating the latent variable `xz`,
222+
the signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
223+
224+
Parameters
225+
----------
226+
xz : Tensor
227+
The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
228+
log_snr : Tensor
229+
The log signal-to-noise ratio tensor, typically of shape (..., 1).
230+
conditions : Tensor, optional
231+
The optional conditioning tensor (e.g. parameters).
232+
training : bool, optional
233+
The training mode flag, which can be used to control behavior during training.
234+
235+
Returns
236+
-------
237+
Tensor
238+
The output tensor from the subnet.
239+
"""
240+
if self._concatenate_subnet_input:
241+
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
242+
return self.subnet(xtc, training=training)
243+
else:
244+
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
245+
200246
def velocity(
201247
self,
202248
xz: Tensor,
@@ -221,7 +267,7 @@ def velocity(
221267
If True, computes the velocity for the stochastic formulation (SDE).
222268
If False, uses the deterministic formulation (ODE).
223269
conditions : Tensor, optional
224-
Optional conditional inputs to the network, such as conditioning variables
270+
Conditional inputs to the network, such as conditioning variables
225271
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
226272
training : bool, optional
227273
Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +284,10 @@ def velocity(
238284
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
239285
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
240286

241-
if conditions is None:
242-
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
243-
else:
244-
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
245-
246-
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
287+
subnet_out = self._apply_subnet(
288+
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
289+
)
290+
pred = self.output_projector(subnet_out, training=training)
247291

248292
x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
249293

@@ -461,11 +505,10 @@ def compute_metrics(
461505
diffused_x = alpha_t * x + sigma_t * eps_t
462506

463507
# calculate output of the network
464-
if conditions is None:
465-
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
466-
else:
467-
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
468-
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
508+
subnet_out = self._apply_subnet(
509+
diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
510+
)
511+
pred = self.output_projector(subnet_out, training=training)
469512

470513
x_pred = self.convert_prediction_to_x(
471514
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
layer_kwargs,
1313
optimal_transport,
1414
weighted_mean,
15+
tensor_utils,
1516
)
1617
from bayesflow.utils.serialization import serialize, deserialize, serializable
1718
from ..inference_network import InferenceNetwork
@@ -90,6 +91,11 @@ def __init__(
9091
Additional keyword arguments for configuring optimal transport. Default is None.
9192
subnet_kwargs: dict[str, any], optional, deprecated
9293
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
94+
concatenate_subnet_input: bool, optional
95+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
96+
into a single vector or passed as separate arguments. If set to False, the subnet
97+
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
98+
and optional 'conditions'. Default is True.
9399
**kwargs
94100
Additional keyword arguments passed to the subnet and other components.
95101
"""
@@ -107,6 +113,7 @@ def __init__(
107113
subnet_kwargs = subnet_kwargs or {}
108114
if subnet == "mlp":
109115
subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
116+
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
110117

111118
self.subnet = find_network(subnet, **subnet_kwargs)
112119
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
@@ -121,16 +128,25 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
121128

122129
self.output_projector.units = xz_shape[-1]
123130

124-
# account for concatenating the time and conditions
125131
input_shape = list(xz_shape)
126-
input_shape[-1] += 1
127-
if conditions_shape is not None:
128-
input_shape[-1] += conditions_shape[-1]
129-
input_shape = tuple(input_shape)
132+
if self._concatenate_subnet_input:
133+
# construct time vector
134+
input_shape[-1] += 1
135+
if conditions_shape is not None:
136+
input_shape[-1] += conditions_shape[-1]
137+
input_shape = tuple(input_shape)
138+
139+
self.subnet.build(input_shape)
140+
out_shape = self.subnet.compute_output_shape(input_shape)
141+
else:
142+
# Multiple separate inputs
143+
time_shape = tuple(xz_shape[:-1]) + (1,) # same batch/sequence dims, 1 feature
144+
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
145+
out_shape = self.subnet.compute_output_shape(
146+
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
147+
)
130148

131-
self.subnet.build(input_shape)
132-
input_shape = self.subnet.compute_output_shape(input_shape)
133-
self.output_projector.build(input_shape)
149+
self.output_projector.build(out_shape)
134150

135151
@classmethod
136152
def from_config(cls, config, custom_objects=None):
@@ -147,22 +163,50 @@ def get_config(self):
147163
"loss_fn": self.loss_fn,
148164
"integrate_kwargs": self.integrate_kwargs,
149165
"optimal_transport_kwargs": self.optimal_transport_kwargs,
166+
"concatenate_subnet_input": self._concatenate_subnet_input,
150167
# we do not need to store subnet_kwargs
151168
}
152169

153170
return base_config | serialize(config)
154171

172+
def _apply_subnet(
173+
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
174+
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
175+
"""
176+
Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
177+
the time `t`, and optional conditions or by returning them separately.
178+
179+
Parameters
180+
----------
181+
x : Tensor
182+
The parameter tensor, typically of shape (..., D), but can vary.
183+
t : Tensor
184+
The time tensor, typically of shape (..., 1).
185+
conditions : Tensor, optional
186+
The optional conditioning tensor (e.g. parameters).
187+
training : bool, optional
188+
The training mode flag, which can be used to control behavior during training.
189+
190+
Returns
191+
-------
192+
Tensor
193+
The output tensor from the subnet.
194+
"""
195+
if self._concatenate_subnet_input:
196+
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
197+
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
198+
return self.subnet(xtc, training=training)
199+
else:
200+
if training is False:
201+
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
202+
return self.subnet(x=x, t=t, conditions=conditions, training=training)
203+
155204
def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
156205
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
157206
time = expand_right_as(time, xz)
158-
time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))
159-
160-
if conditions is None:
161-
xtc = keras.ops.concatenate([xz, time], axis=-1)
162-
else:
163-
xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)
164207

165-
return self.output_projector(self.subnet(xtc, training=training), training=training)
208+
subnet_out = self._apply_subnet(xz, time, conditions, training=training)
209+
return self.output_projector(subnet_out, training=training)
166210

167211
def _velocity_trace(
168212
self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False

0 commit comments

Comments
 (0)