66from bayesflow .utils .decorators import allow_batch_size
77
88from bayesflow .utils import numpy_utils as npu
9+ from bayesflow .utils import logging
910
1011from types import FunctionType
12+ from typing import Literal
1113
1214from .simulator import Simulator
1315from .lambda_simulator import LambdaSimulator
@@ -22,6 +24,8 @@ def __init__(
2224 p : Sequence [float ] = None ,
2325 logits : Sequence [float ] = None ,
2426 use_mixed_batches : bool = True ,
27+ key_conflicts : Literal ["drop" , "fill" , "error" ] = "drop" ,
28+ fill_value : float = np .nan ,
2529 shared_simulator : Simulator | Callable [[Sequence [int ]], dict [str , any ]] = None ,
2630 ):
2731 """
@@ -38,11 +42,21 @@ def __init__(
3842 A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
3943 If neither `p` nor `logits` is provided, defaults to uniform logits.
4044 use_mixed_batches : bool, optional
41- If True, samples in a batch are drawn from different models. If False, the entire batch
42- is drawn from a single model chosen according to the model probabilities. Default is True.
45+ Whether to draw samples in a batch from different models.
46+
47+ - If True (default), each sample in a batch may come from a different model.
48+ - If False, the entire batch is drawn from a single model, selected according to model probabilities.
49+ key_conflicts : str, optional
50+ Policy for handling keys that are missing in the output of some models, when using mixed batches.
51+
52+ - "drop" (default): Drop conflicting keys from the batch output.
53+ - "fill": Fill missing keys with the specified value.
54+ - "error": An error is raised when key conflicts are detected.
55+ fill_value : float, optional
56+ If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument.
4357 shared_simulator : Simulator or Callable, optional
4458 A shared simulator whose outputs are passed to all model simulators. If a function is
45- provided, it is wrapped in a ` LambdaSimulator` with batching enabled.
59+ provided, it is wrapped in a :py:class:`~bayesflow.simulators. LambdaSimulator` with batching enabled.
4660 """
4761 self .simulators = simulators
4862
@@ -68,6 +82,9 @@ def __init__(
6882
6983 self .logits = logits
7084 self .use_mixed_batches = use_mixed_batches
85+ self .key_conflicts = key_conflicts
86+ self .fill_value = fill_value
87+ self ._key_conflicts_warning = True
7188
7289 @allow_batch_size
7390 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
@@ -105,6 +122,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
105122 sims = [
106123 simulator .sample (n , ** (kwargs | data )) for simulator , n in zip (self .simulators , model_counts ) if n > 0
107124 ]
125+ sims = self ._handle_key_conflicts (sims , model_counts )
108126 sims = tree_concatenate (sims , numpy = True )
109127 data |= sims
110128
@@ -118,3 +136,58 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
118136 model_indices = npu .one_hot (np .full (batch_shape , model_index , dtype = "int32" ), num_models )
119137
120138 return data | {"model_indices" : model_indices }
139+
140+ def _handle_key_conflicts (self , sims , batch_sizes ):
141+ batch_sizes = [b for b in batch_sizes if b > 0 ]
142+
143+ keys , all_keys , common_keys , missing_keys = self ._determine_key_conflicts (sims = sims )
144+
145+ # all sims have the same keys
146+ if all_keys == common_keys :
147+ return sims
148+
149+ if self .key_conflicts == "drop" :
150+ sims = [{k : v for k , v in sim .items () if k in common_keys } for sim in sims ]
151+ return sims
152+ elif self .key_conflicts == "fill" :
153+ combined_sims = {}
154+ for sim in sims :
155+ combined_sims = combined_sims | sim
156+ for i , sim in enumerate (sims ):
157+ for missing_key in missing_keys [i ]:
158+ shape = combined_sims [missing_key ].shape
159+ shape = list (shape )
160+ shape [0 ] = batch_sizes [i ]
161+ sim [missing_key ] = np .full (shape = shape , fill_value = self .fill_value )
162+ return sims
163+ elif self .key_conflicts == "error" :
164+ raise ValueError (
165+ "Different simulators provide outputs with different keys, cannot combine them into one batch."
166+ )
167+
168+ def _determine_key_conflicts (self , sims ):
169+ keys = [set (sim .keys ()) for sim in sims ]
170+ all_keys = set .union (* keys )
171+ common_keys = set .intersection (* keys )
172+ missing_keys = [all_keys - k for k in keys ]
173+
174+ if all_keys == common_keys :
175+ return keys , all_keys , common_keys , missing_keys
176+
177+ if self ._key_conflicts_warning :
178+ # issue warning only once
179+ self ._key_conflicts_warning = False
180+
181+ if self .key_conflicts == "drop" :
182+ logging .info (
183+ f"Incompatible simulator output. \
184+ The following keys will be dropped: { ', ' .join (sorted (all_keys - common_keys ))} ."
185+ )
186+ elif self .key_conflicts == "fill" :
187+ logging .info (
188+ f"Incompatible simulator output. \
189+ Attempting to replace keys: { ', ' .join (sorted (all_keys - common_keys ))} , where missing, \
190+ with value { self .fill_value } ."
191+ )
192+
193+ return keys , all_keys , common_keys , missing_keys
0 commit comments