Skip to content

Commit a9e9664

Browse files
committed
Refactor optimal transport logic in flow matching.
Updated the optimal transport implementation to prioritize resampling x1 due to potential data noise and outliers, ensuring better stability across batches. Adjusted conditions to align properly with assignments returned by the optimal transport function. These changes improve robustness and clarity in the resampling process.
1 parent 41690de commit a9e9664

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,20 @@ def compute_metrics(
256256
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
257257

258258
if self.use_optimal_transport:
259-
x1, x0, conditions = optimal_transport(
260-
x1, x0, conditions, seed=self.seed_generator, **self.optimal_transport_kwargs
259+
# we must choose between resampling x0 or x1
260+
# since the data is possibly noisy and may contain outliers, it is better
261+
# to possibly drop some samples from x1 than from x0
262+
# in the marginal over multiple batches, this is not a problem
263+
x0, x1, assignments = optimal_transport(
264+
x0,
265+
x1,
266+
seed=self.seed_generator,
267+
**self.optimal_transport_kwargs,
268+
return_assignments=True,
261269
)
270+
if conditions is not None:
271+
# conditions must be resampled along with x1
272+
conditions = conditions[assignments]
262273

263274
t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
264275
t = expand_right_as(t, x0)

0 commit comments

Comments
 (0)