Commit afef095

Add classes and transforms to simplify multimodal training (#473)
* Add classes and transforms to simplify multimodal training
  - Add class `MultimodalSummaryNetwork` to combine multiple summary networks, each for one modality.
  - Add transforms `Group` and `Ungroup` to gather the multimodal inputs in one variable (usually "summary_variables").
  - Add tests for the new behavior.
* [no ci] add tutorial notebook for multimodal data
* [no ci] add missing training argument
* rename MultimodalSummaryNetwork to FusionNetwork
* [no ci] clarify that the network implements late fusion
1 parent 35cd671 commit afef095

File tree: 12 files changed, +1067 -0 lines changed

bayesflow/adapters/adapter.py
Lines changed: 48 additions & 0 deletions

@@ -14,6 +14,7 @@
     Drop,
     ExpandDims,
     FilterTransform,
+    Group,
     Keep,
     Log,
     MapTransform,
@@ -25,6 +26,7 @@
     Standardize,
     ToArray,
     Transform,
+    Ungroup,
     RandomSubsample,
     Take,
 )
@@ -600,6 +602,52 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
         self.transforms.append(transform)
         return self
 
+    def group(self, keys: Sequence[str], into: str, *, prefix: str = ""):
+        """Append a :py:class:`~transforms.Group` transform to the adapter.
+
+        Groups the given variables as a dictionary in the key `into`. As most transforms do
+        not support nested structures, this should usually be the last transform in the adapter.
+
+        Parameters
+        ----------
+        keys : Sequence of str
+            The names of the variables to group together.
+        into : str
+            The name of the variable to store the grouped variables in.
+        prefix : str, optional
+            An optional common prefix of the variable names before grouping, which will be removed after grouping.
+
+        Raises
+        ------
+        ValueError
+            If a prefix is specified, but a provided key does not start with the prefix.
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = Group(keys=keys, into=into, prefix=prefix)
+        self.transforms.append(transform)
+        return self
+
+    def ungroup(self, key: str, *, prefix: str = ""):
+        """Append an :py:class:`~transforms.Ungroup` transform to the adapter.
+
+        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
+        not support nested structures, so this can be used to flatten a nested structure.
+        The nesting can be re-established after the transforms using the :py:meth:`group` method.
+
+        Parameters
+        ----------
+        key : str
+            The name of the variable to ungroup. The corresponding variable has to be a dictionary.
+        prefix : str, optional
+            An optional common prefix that will be added to the ungrouped variable names. This can be necessary
+            to avoid duplicate names.
+        """
+        transform = Ungroup(key=key, prefix=prefix)
+        self.transforms.append(transform)
+        return self
+
     def keep(self, keys: str | Sequence[str]):
         """Append a :py:class:`~transforms.Keep` transform to the adapter.

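A minimal usage sketch of the new adapter methods (the modality names "x1" and "x2" are hypothetical placeholders, not part of this commit). Grouping should come last in the chain, since most transforms do not support nested structures:

import bayesflow as bf

# Hypothetical modalities "x1" and "x2": group them into one dictionary-valued
# variable so a multimodal summary network receives them keyed by name.
adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .group(["x1", "x2"], into="summary_variables")
)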
bayesflow/adapters/transforms/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,7 @@
 from .elementwise_transform import ElementwiseTransform
 from .expand_dims import ExpandDims
 from .filter_transform import FilterTransform
+from .group import Group
 from .keep import Keep
 from .log import Log
 from .map_transform import MapTransform
@@ -25,6 +26,7 @@
 from .transform import Transform
 from .random_subsample import RandomSubsample
 from .take import Take
+from .ungroup import Ungroup
 
 from ...utils._docs import _add_imports_to_all

bayesflow/adapters/transforms/group.py (new file)
Lines changed: 81 additions & 0 deletions

from collections.abc import Sequence
from .transform import Transform
from bayesflow.utils.serialization import serializable, serialize


@serializable("bayesflow.adapters")
class Group(Transform):
    def __init__(self, keys: Sequence[str], into: str, prefix: str = ""):
        """Groups the given variables as a dictionary in the key `into`. As most transforms do
        not support nested structures, this should usually be the last transform.

        Parameters
        ----------
        keys : Sequence of str
            The names of the variables to group together.
        into : str
            The name of the variable to store the grouped variables in.
        prefix : str, optional
            A common prefix of the ungrouped variable names, which will be removed after grouping.

        Raises
        ------
        ValueError
            If a prefix is specified, but a provided key does not start with the prefix.
        """
        super().__init__()
        self.keys = keys
        self.into = into
        self.prefix = prefix
        for key in keys:
            if not key.startswith(prefix):
                raise ValueError(f"If prefix is specified, all keys have to start with prefix. Found '{key}'.")

    def get_config(self) -> dict:
        return serialize({"keys": self.keys, "into": self.into, "prefix": self.prefix})

    def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
        data = data.copy()

        data[self.into] = data.get(self.into, {})
        for key in self.keys:
            if key not in data:
                if strict:
                    raise KeyError(f"Missing key: {key!r}")
            else:
                data[self.into][key[len(self.prefix) :]] = data.pop(key)

        return data

    def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
        data = data.copy()

        if strict and self.into not in data:
            raise KeyError(f"Missing key: {self.into!r}")
        elif self.into not in data:
            return data

        for key in self.keys:
            internal_key = key[len(self.prefix) :]
            if internal_key not in data[self.into]:
                if strict:
                    raise KeyError(f"Missing key: {internal_key!r}")
            else:
                data[key] = data[self.into].pop(internal_key)

        if len(data[self.into]) == 0:
            del data[self.into]

        return data

    def extra_repr(self) -> str:
        return f"{self.keys!r} -> {self.into!r}"

    def log_det_jac(
        self,
        data: dict[str, any],
        log_det_jac: dict[str, any],
        inverse: bool = False,
        **kwargs,
    ):
        return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
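To make the semantics concrete, a small sketch of the forward/inverse round trip (the key names are made up; the import path follows the __init__.py change above):

import numpy as np
from bayesflow.adapters.transforms import Group

transform = Group(keys=["mod_a", "mod_b"], into="summary_variables", prefix="mod_")

data = {"mod_a": np.zeros(3), "mod_b": np.ones(3)}
grouped = transform.forward(data)
# -> {"summary_variables": {"a": ..., "b": ...}}  (prefix stripped inside the group)

restored = transform.inverse(grouped)
# -> {"mod_a": ..., "mod_b": ...}  (original keys restored, empty group removed)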
bayesflow/adapters/transforms/ungroup.py (new file)
Lines changed: 83 additions & 0 deletions

from .transform import Transform
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable("bayesflow.adapters")
class Ungroup(Transform):
    def __init__(self, key: str, prefix: str = ""):
        """
        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
        not support nested structures, so this can be used to flatten a nested structure.
        It can later on be reassembled using the :py:class:`bayesflow.adapters.transforms.Group` transform.

        Parameters
        ----------
        key : str
            The name of the variable to ungroup. The variable has to be a dictionary.
        prefix : str, optional
            An optional common prefix that will be added to the ungrouped variable names. This can be necessary
            to avoid duplicate names.
        """
        super().__init__()
        self.key = key
        self.prefix = prefix
        self._ungrouped = None

    def get_config(self) -> dict:
        return serialize({"key": self.key, "prefix": self.prefix, "_ungrouped": self._ungrouped})

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        config = deserialize(config, custom_objects)
        _ungrouped = config.pop("_ungrouped")
        transform = cls(**config)
        transform._ungrouped = _ungrouped
        return transform

    def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
        data = data.copy()

        if self.key not in data and strict:
            raise KeyError(f"Missing key: {self.key!r}")
        elif self.key not in data:
            return data

        ungrouped = []
        for k, v in data.pop(self.key).items():
            new_key = f"{self.prefix}{k}"
            if new_key in data:
                raise ValueError(
                    f"Encountered duplicate key during ungrouping: '{new_key}'."
                    " Use `prefix` to specify a unique prefix that is added to the key"
                )
            ungrouped.append(new_key)
            data[new_key] = v
        if self._ungrouped is None:
            self._ungrouped = sorted(ungrouped)
        else:
            self._ungrouped = sorted(list(set(self._ungrouped + ungrouped)))

        return data

    def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
        data = data.copy()

        data[self.key] = {}
        for key in self._ungrouped:
            if key not in data:
                if strict:
                    raise KeyError(f"Missing key: {key!r}")
            else:
                recovered_key = key[len(self.prefix) :]
                data[self.key][recovered_key] = data.pop(key)

        return data

    def log_det_jac(
        self,
        data: dict[str, any],
        log_det_jac: dict[str, any],
        inverse: bool = False,
        **kwargs,
    ):
        return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
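A small sketch of the flatten-then-reassemble round trip (key names are made up):

import numpy as np
from bayesflow.adapters.transforms import Ungroup

transform = Ungroup(key="summary_variables", prefix="sv_")

data = {"summary_variables": {"a": np.zeros(3), "b": np.ones(3)}}
flat = transform.forward(data)
# -> {"sv_a": ..., "sv_b": ...}  (nested dict flattened, prefix added)

nested = transform.inverse(flat)
# -> {"summary_variables": {"a": ..., "b": ...}}  (nesting re-established)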

bayesflow/networks/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 from .inference_network import InferenceNetwork
 from .point_inference_network import PointInferenceNetwork
 from .mlp import MLP
+from .fusion_network import FusionNetwork
 from .summary_network import SummaryNetwork
 from .time_series_network import TimeSeriesNetwork
 from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
bayesflow/networks/fusion_network/__init__.py (new file)
Lines changed: 1 addition & 0 deletions

from .fusion_network import FusionNetwork
bayesflow/networks/fusion_network/fusion_network.py (new file)
Lines changed: 123 additions & 0 deletions

from collections.abc import Mapping
from ..summary_network import SummaryNetwork
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Tensor, Shape
import keras
from keras import ops


@serializable("bayesflow.networks")
class FusionNetwork(SummaryNetwork):
    def __init__(
        self,
        backbones: Mapping[str, keras.Layer],
        head: keras.Layer | None = None,
        **kwargs,
    ):
        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.

        Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed
        by the correct summary network. This means the "summary_variables" entry to the approximator has to be
        a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.

        This network implements _late_ fusion: the outputs of the individual summary networks are concatenated,
        and can be further processed by another neural network (`head`).

        Parameters
        ----------
        backbones : dict
            A dictionary with names of inputs as keys and corresponding summary networks as values.
        head : keras.Layer, optional
            A network to further process the concatenated outputs of the summary networks. By default,
            the concatenated outputs are returned without further processing.
        **kwargs
            Additional keyword arguments that are passed to the :py:class:`~bayesflow.networks.SummaryNetwork`
            base class.
        """
        super().__init__(**kwargs)
        self.backbones = backbones
        self.head = head
        self._ordered_keys = sorted(list(self.backbones.keys()))

    def build(self, inputs_shape: Mapping[str, Shape]):
        if self.built:
            return
        output_shapes = []
        for k, shape in inputs_shape.items():
            if not self.backbones[k].built:
                self.backbones[k].build(shape)
            output_shapes.append(self.backbones[k].compute_output_shape(shape))
        if self.head and not self.head.built:
            fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
            self.head.build(fusion_input_shape)
        self.built = True

    def compute_output_shape(self, inputs_shape: Mapping[str, Shape]):
        output_shapes = []
        for k, shape in inputs_shape.items():
            output_shapes.append(self.backbones[k].compute_output_shape(shape))
        output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
        if self.head:
            output_shape = self.head.compute_output_shape(output_shape)
        return output_shape

    def call(self, inputs: Mapping[str, Tensor], training=False):
        """
        Parameters
        ----------
        inputs : dict[str, Tensor]
            Each value in the dictionary is the input to the summary network with the corresponding key.
        training : bool, optional
            Whether the model is in training mode, affecting layers like dropout and
            batch normalization. Default is False.
        """
        outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
        outputs = ops.concatenate(outputs, axis=-1)
        if self.head is None:
            return outputs
        return self.head(outputs, training=training)

    def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", **kwargs) -> dict[str, Tensor]:
        """
        Parameters
        ----------
        inputs : dict[str, Tensor]
            Each value in the dictionary is the input to the summary network with the corresponding key.
        stage : str, optional
            The current stage of the model, e.g., "training", "validation" or "inference".
            Default is "training".
        **kwargs
            Additional keyword arguments.
        """
        metrics = {"loss": [], "outputs": []}

        for k in self._ordered_keys:
            if isinstance(self.backbones[k], SummaryNetwork):
                metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs)
                metrics["outputs"].append(metrics_k["outputs"])
                if "loss" in metrics_k:
                    metrics["loss"].append(metrics_k["loss"])
            else:
                metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training"))
        if len(metrics["loss"]) == 0:
            del metrics["loss"]
        else:
            metrics["loss"] = ops.sum(metrics["loss"])
        metrics["outputs"] = ops.concatenate(metrics["outputs"], axis=-1)
        if self.head is not None:
            metrics["outputs"] = self.head(metrics["outputs"], training=stage == "training")

        return metrics

    def get_config(self) -> dict:
        base_config = super().get_config()
        config = {
            "backbones": self.backbones,
            "head": self.head,
        }
        return base_config | serialize(config)

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        config = deserialize(config, custom_objects=custom_objects)
        return cls(**config)
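A construction sketch, reusing the SetTransformer and TimeSeriesTransformer networks exported from bayesflow/networks above. The constructor arguments and the `head` layer are illustrative assumptions, not part of this commit:

import keras
import bayesflow as bf

# One backbone per modality; the keys must match the keys that the adapter
# grouped into "summary_variables" (here the hypothetical "x1" and "x2").
summary_network = bf.networks.FusionNetwork(
    backbones={
        "x1": bf.networks.SetTransformer(summary_dim=8),
        "x2": bf.networks.TimeSeriesTransformer(summary_dim=8),
    },
    head=keras.layers.Dense(16),  # optional: further process the concatenated summaries
)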
