Skip to content

Commit 35cd671

Browse files
Kucharssim and vpratz authored
ModelComparisonSimulator: handle different outputs from individual simulators (#452)
Adds option to drop, fill or error when different keys are encountered in the outputs of different simulators. Fixes #441. --------- Co-authored-by: Valentin Pratz <git@valentinpratz.de>
1 parent e13f7c4 commit 35cd671

File tree

3 files changed

+148
-3
lines changed

3 files changed

+148
-3
lines changed

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from bayesflow.utils.decorators import allow_batch_size
77

88
from bayesflow.utils import numpy_utils as npu
9+
from bayesflow.utils import logging
910

1011
from types import FunctionType
12+
from typing import Literal
1113

1214
from .simulator import Simulator
1315
from .lambda_simulator import LambdaSimulator
@@ -22,6 +24,8 @@ def __init__(
2224
p: Sequence[float] = None,
2325
logits: Sequence[float] = None,
2426
use_mixed_batches: bool = True,
27+
key_conflicts: Literal["drop", "fill", "error"] = "drop",
28+
fill_value: float = np.nan,
2529
shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
2630
):
2731
"""
@@ -38,11 +42,21 @@ def __init__(
3842
A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
3943
If neither `p` nor `logits` is provided, defaults to uniform logits.
4044
use_mixed_batches : bool, optional
41-
If True, samples in a batch are drawn from different models. If False, the entire batch
42-
is drawn from a single model chosen according to the model probabilities. Default is True.
45+
Whether to draw samples in a batch from different models.
46+
47+
- If True (default), each sample in a batch may come from a different model.
48+
- If False, the entire batch is drawn from a single model, selected according to model probabilities.
49+
key_conflicts : str, optional
50+
Policy for handling keys that are missing in the output of some models, when using mixed batches.
51+
52+
- "drop" (default): Drop conflicting keys from the batch output.
53+
- "fill": Fill missing keys with the specified value.
54+
- "error": An error is raised when key conflicts are detected.
55+
fill_value : float, optional
56+
If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument.
4357
shared_simulator : Simulator or Callable, optional
4458
A shared simulator whose outputs are passed to all model simulators. If a function is
45-
provided, it is wrapped in a `LambdaSimulator` with batching enabled.
59+
provided, it is wrapped in a :py:class:`~bayesflow.simulators.LambdaSimulator` with batching enabled.
4660
"""
4761
self.simulators = simulators
4862

@@ -68,6 +82,9 @@ def __init__(
6882

6983
self.logits = logits
7084
self.use_mixed_batches = use_mixed_batches
85+
self.key_conflicts = key_conflicts
86+
self.fill_value = fill_value
87+
self._key_conflicts_warning = True
7188

7289
@allow_batch_size
7390
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -105,6 +122,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
105122
sims = [
106123
simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
107124
]
125+
sims = self._handle_key_conflicts(sims, model_counts)
108126
sims = tree_concatenate(sims, numpy=True)
109127
data |= sims
110128

@@ -118,3 +136,58 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
118136
model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)
119137

120138
return data | {"model_indices": model_indices}
139+
140+
def _handle_key_conflicts(self, sims, batch_sizes):
    """
    Reconcile the outputs of the individual simulators so they can be concatenated.

    Parameters
    ----------
    sims : list of dict
        Outputs of the simulators that were asked for at least one sample.
    batch_sizes : sequence of int
        Number of samples requested from every simulator. Entries that are not
        positive are discarded so that the remaining entries line up with `sims`.

    Returns
    -------
    list of dict
        `sims` unchanged if all outputs share the same keys. Otherwise, depending on
        `self.key_conflicts`: new dictionaries restricted to the common keys ("drop"),
        or copies in which every missing key is an array filled with `self.fill_value` ("fill").

    Raises
    ------
    ValueError
        If the keys differ and `self.key_conflicts` is "error", or if it is not one of
        "drop", "fill" or "error".
    """
    # keep only the batch sizes of simulators that actually produced an output
    batch_sizes = [b for b in batch_sizes if b > 0]

    _, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims)

    # all sims have the same keys
    if all_keys == common_keys:
        return sims

    if self.key_conflicts == "drop":
        return [{k: v for k, v in sim.items() if k in common_keys} for sim in sims]

    if self.key_conflicts == "fill":
        # One template value per conflicting key is enough to derive the trailing shape;
        # as before, the last simulator providing the key wins.
        templates = {}
        for sim in sims:
            for key in all_keys - common_keys:
                if key in sim:
                    templates[key] = sim[key]

        filled_sims = []
        for sim, batch_size, missing in zip(sims, batch_sizes, missing_keys):
            # shallow copy so the dictionaries returned by the simulators are not modified
            filled = dict(sim)
            for key in missing:
                trailing_shape = np.shape(templates[key])[1:]
                filled[key] = np.full(shape=(batch_size, *trailing_shape), fill_value=self.fill_value)
            filled_sims.append(filled)
        return filled_sims

    if self.key_conflicts == "error":
        raise ValueError(
            "Different simulators provide outputs with different keys, cannot combine them into one batch."
        )

    # previously an unknown policy fell through and returned None
    raise ValueError(
        f"Unknown value for key_conflicts: {self.key_conflicts!r}. Expected one of 'drop', 'fill' or 'error'."
    )
167+
168+
def _determine_key_conflicts(self, sims):
    """
    Compare the keys of the simulator outputs and log a one-time notice on conflicts.

    Parameters
    ----------
    sims : list of dict
        Outputs of the individual simulators.

    Returns
    -------
    keys : list of set
        The key set of every entry of `sims`.
    all_keys : set
        Union of all key sets.
    common_keys : set
        Intersection of all key sets (empty if `sims` is empty).
    missing_keys : list of set
        For every entry of `sims`, the keys in `all_keys` that it lacks.
    """
    keys = [set(sim.keys()) for sim in sims]
    # set.union(*[]) / set.intersection(*[]) raise TypeError, so guard the empty case
    all_keys = set().union(*keys)
    common_keys = set.intersection(*keys) if keys else set()
    missing_keys = [all_keys - k for k in keys]

    conflicting_keys = all_keys - common_keys

    if conflicting_keys and self._key_conflicts_warning:
        # issue warning only once
        self._key_conflicts_warning = False

        listed = ", ".join(sorted(conflicting_keys))

        # Adjacent literals instead of backslash continuations: a backslash inside a string
        # literal keeps the leading whitespace of the continued source line in the message.
        if self.key_conflicts == "drop":
            logging.info(f"Incompatible simulator output. The following keys will be dropped: {listed}.")
        elif self.key_conflicts == "fill":
            logging.info(
                f"Incompatible simulator output. Attempting to replace keys: {listed}, where missing, "
                f"with value {self.fill_value}."
            )

    return keys, all_keys, common_keys, missing_keys

tests/test_simulators/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,56 @@ def likelihood(mu, n):
167167
return make_simulator([prior, likelihood], meta_fn=context)
168168

169169

170+
@pytest.fixture()
def multimodel():
    """Model comparison simulator over two Gaussian models that share a context simulator for `n`."""
    from bayesflow.simulators import make_simulator, ModelComparisonSimulator

    def context(batch_size):
        # shared across both models: number of observations per data set
        return {"n": np.random.randint(10, 100)}

    def prior_0():
        return {"mu": 0}

    def prior_1():
        return {"mu": np.random.standard_normal()}

    def likelihood(n, mu):
        return {"y": np.random.normal(mu, 1, n)}

    # both models use the same likelihood and differ only in the prior on mu
    simulator_0, simulator_1 = (make_simulator([prior, likelihood]) for prior in (prior_0, prior_1))

    return ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context)
192+
193+
194+
@pytest.fixture(params=["drop", "fill", "error"])
def multimodel_key_conflicts(request):
    """Model comparison simulator whose two models share the key `x` but differ in `w` / `c`.

    Parametrized over the `key_conflicts` policy passed to the simulator.
    """
    from bayesflow.simulators import make_simulator, ModelComparisonSimulator

    generator = np.random.default_rng()

    def prior_1():
        return {"w": generator.uniform()}

    def prior_2():
        return {"c": generator.uniform()}

    def model_1(w):
        return {"x": w}

    def model_2(c):
        return {"x": c}

    first = make_simulator([prior_1, model_1])
    second = make_simulator([prior_2, model_2])

    return ModelComparisonSimulator(simulators=[first, second], key_conflicts=request.param)
218+
219+
170220
@pytest.fixture()
def fixed_n():
    """Constant value for `n` used by tests that fix the simulator inputs."""
    n = 5
    return n

tests/test_simulators/test_simulators.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import keras
23
import numpy as np
34

@@ -47,3 +48,24 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
4748
assert samples["mu"].shape == (batch_size, 1)
4849
assert np.all(samples["mu"] == fixed_mu)
4950
assert samples["y"].shape == (batch_size, fixed_n)
51+
52+
53+
def test_multimodel_sample(multimodel, batch_size):
    """Sampling from the multi-model simulator yields all expected keys with batched shapes."""
    draws = multimodel.sample(batch_size)

    expected_keys = {"n", "mu", "y", "model_indices"}
    assert set(draws) == expected_keys

    mu_shape = draws["mu"].shape
    y_shape = draws["y"].shape
    assert mu_shape == (batch_size, 1)
    assert y_shape == (batch_size, draws["n"])
59+
60+
61+
def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size):
    """Check the sampled keys (or the raised error) for every `key_conflicts` policy."""
    policy = multimodel_key_conflicts.key_conflicts

    if policy == "error":
        # conflicting keys must be rejected
        with pytest.raises(ValueError):
            multimodel_key_conflicts.sample(batch_size)
        return

    if policy == "drop":
        draws = multimodel_key_conflicts.sample(batch_size)
        assert set(draws) == {"x", "model_indices"}
        return

    if policy == "fill":
        draws = multimodel_key_conflicts.sample(batch_size)
        assert set(draws) == {"x", "model_indices", "c", "w"}
        # the NaN counts of "c" and "w" together must equal the batch size
        nan_total = np.isnan(draws["c"]).sum() + np.isnan(draws["w"]).sum()
        assert nan_total == batch_size

0 commit comments

Comments
 (0)