Skip to content

Commit 41690de

Browse files
committed
Refactor optimal transport utilities and add log-Sinkhorn
Removed unimplemented Hungarian algorithm and streamline Sinkhorn implementation. Introduced log-Sinkhorn for numerical stability and modularized cost computation (e.g., Euclidean distance). Updated the interface for improved usability and scalability.
1 parent 2fac738 commit 41690de

File tree

6 files changed

+160
-256
lines changed

6 files changed

+160
-256
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from .optimal_transport import optimal_transport
2-
from .sinkhorn import sinkhorn, sinkhorn_indices, sinkhorn_plan
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import keras
2+
3+
4+
def euclidean(x1, x2):
5+
# TODO: rename and move this function
6+
result = x1[:, None] - x2[None, :]
7+
shape = list(keras.ops.shape(result))
8+
shape[2:] = [-1]
9+
result = keras.ops.reshape(result, shape)
10+
result = keras.ops.norm(result, ord=2, axis=-1)
11+
return result

bayesflow/utils/optimal_transport/hungarian.py

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import keras
2+
3+
from .. import logging
4+
from ..tensor_utils import is_symbolic_tensor
5+
6+
from .euclidean import euclidean
7+
8+
9+
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
10+
"""
11+
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
12+
Significantly slower than the unstabilized version, so use only when you need numerical stability.
13+
"""
14+
log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
15+
assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
16+
assignments = keras.ops.squeeze(assignments, axis=1)
17+
18+
return assignments
19+
20+
21+
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
22+
"""
23+
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
24+
Significantly slower than the unstabilized version, so use only when you need numerical stability.
25+
"""
26+
cost = euclidean(x1, x2)
27+
28+
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
29+
30+
if is_symbolic_tensor(log_plan):
31+
return log_plan
32+
33+
def contains_nans(plan):
34+
return keras.ops.any(keras.ops.isnan(plan))
35+
36+
def is_converged(plan):
37+
# for convergence, the plan should be doubly stochastic
38+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
39+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
40+
return conv0 & conv1
41+
42+
def cond(_, plan):
43+
# break the while loop if the plan contains nans or is converged
44+
return ~(contains_nans(plan) | is_converged(plan))
45+
46+
def body(steps, plan):
47+
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
48+
plan = keras.ops.log_softmax(plan, axis=0)
49+
plan = keras.ops.log_softmax(plan, axis=1)
50+
51+
return steps + 1, plan
52+
53+
steps = 0
54+
steps, log_plan = keras.ops.while_loop(cond, body, (steps, log_plan), maximum_iterations=max_steps)
55+
56+
def do_nothing():
57+
pass
58+
59+
def log_steps():
60+
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
61+
62+
logging.info(msg, steps)
63+
64+
def warn_convergence():
65+
marginals = keras.ops.logsumexp(log_plan, axis=0)
66+
deviations = keras.ops.abs(marginals)
67+
badness = 100.0 * keras.ops.exp(keras.ops.max(deviations))
68+
69+
msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
70+
71+
logging.warning(msg, max_steps, badness)
72+
73+
def warn_nans():
74+
msg = "Log-Sinkhorn-Knopp produced NaNs."
75+
logging.warning(msg)
76+
77+
keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
78+
keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)
79+
80+
return log_plan
Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from bayesflow.types import Tensor
2-
3-
from .hungarian import hungarian
1+
from .log_sinkhorn import log_sinkhorn
42
from .sinkhorn import sinkhorn
53

4+
methods = {
5+
"sinkhorn": sinkhorn,
6+
"sinkhorn_knopp": sinkhorn,
7+
"log_sinkhorn": log_sinkhorn,
8+
"log_sinkhorn_knopp": log_sinkhorn,
9+
}
10+
611

7-
def optimal_transport(
8-
x1: Tensor, x2: Tensor, *aux: Tensor, method: str = "sinkhorn_knopp", **kwargs
9-
) -> (Tensor, Tensor):
12+
def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, **kwargs):
1013
"""Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method
1114
and cost matrix used.
1215
@@ -22,27 +25,21 @@ def optimal_transport(
2225
:param x2: Tensor of shape (m, ...)
2326
Samples from the second distribution.
2427
25-
:param aux: Tensors of shape (n, ...)
26-
Auxiliary tensors to be permuted along with x1.
27-
Note that x2 is never permuted for all currently available methods.
28-
2928
:param method: Method used to compute the transport cost.
30-
Default: 'sinkhorn_knopp'
29+
Default: 'log_sinkhorn'
3130
32-
:param kwargs: Additional keyword arguments passed to the optimization method.
31+
:param return_assignments: Whether to return the assignment indices.
32+
Default: False
33+
34+
:param kwargs: Additional keyword arguments that are passed to the optimization method.
3335
3436
:return: Tensors of shapes (n, ...) and (m, ...)
3537
x1 and x2 in optimal transport permutation order.
3638
"""
37-
methods = {
38-
"hungarian": hungarian,
39-
"sinkhorn": sinkhorn,
40-
"sinkhorn_knopp": sinkhorn,
41-
}
42-
43-
method = method.lower()
39+
assignments = methods[method.lower()](x1, x2, **kwargs)
40+
x2 = x2[assignments]
4441

45-
if method not in methods:
46-
raise ValueError(f"Unsupported method name: '{method}'.")
42+
if return_assignments:
43+
return x1, x2, assignments
4744

48-
return methods[method](x1, x2, *aux, **kwargs)
45+
return x1, x2

0 commit comments

Comments
 (0)