@@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
1111 """
1212 Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
1313
14- Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic
14+ Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
1515 transport plan, containing assignment probabilities.
1616 The permutation is then sampled randomly according to the transport plan.
1717
@@ -27,12 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
2727 :param seed: Random seed to use for sampling indices.
2828 Default: None, which means the seed will be auto-determined for non-compiled contexts.
2929
30- :return: Tensor of shape (m ,)
30+ :return: Tensor of shape (n ,)
3131 Assignment indices for x2.
3232
3333 """
3434 plan = sinkhorn_plan (x1 , x2 , ** kwargs )
35- assignments = keras .random .categorical (plan , num_samples = 1 , seed = seed )
35+
36+ # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
37+ # such that x2[assignments] matches x1
38+ assignments = keras .random .categorical (keras .ops .log (plan ), num_samples = 1 , seed = seed )
3639 assignments = keras .ops .squeeze (assignments , axis = 1 )
3740
3841 return assignments
@@ -42,7 +45,7 @@ def sinkhorn_plan(
4245 x1 : Tensor ,
4346 x2 : Tensor ,
4447 regularization : float = 1.0 ,
45- max_steps : int = 10_000 ,
48+ max_steps : int = None ,
4649 rtol : float = 1e-5 ,
4750 atol : float = 1e-8 ,
4851) -> Tensor :
@@ -59,7 +62,7 @@ def sinkhorn_plan(
5962 Controls the standard deviation of the Gaussian kernel.
6063
6164 :param max_steps: Maximum number of iterations, or None to run until convergence.
62- Default: 10_000
65+ Default: None
6366
6467 :param rtol: Relative tolerance for convergence.
6568 Default: 1e-5.
@@ -71,17 +74,20 @@ def sinkhorn_plan(
7174 The transport probabilities.
7275 """
7376 cost = euclidean (x1 , x2 )
77+ cost_scaled = - cost / regularization
7478
75- # initialize the transport plan from a gaussian kernel
76- plan = keras .ops .exp (cost / - (regularization * keras .ops .mean (cost ) + 1e-16 ))
79+ # initialize transport plan from a gaussian kernel
80+ # (more numerically stable version of keras.ops.exp(-cost/regularization))
81+ plan = keras .ops .exp (cost_scaled - keras .ops .max (cost_scaled ))
82+ n , m = keras .ops .shape (cost )
7783
7884 def contains_nans (plan ):
7985 return keras .ops .any (keras .ops .isnan (plan ))
8086
8187 def is_converged (plan ):
82- # for convergence, the plan should be doubly stochastic
83- conv0 = keras .ops .all (keras .ops .isclose (keras .ops .sum (plan , axis = 0 ), 1.0 , rtol = rtol , atol = atol ))
84- conv1 = keras .ops .all (keras .ops .isclose (keras .ops .sum (plan , axis = 1 ), 1.0 , rtol = rtol , atol = atol ))
88+ # for convergence, the target marginals must match
89+ conv0 = keras .ops .all (keras .ops .isclose (keras .ops .sum (plan , axis = 0 ), 1.0 / m , rtol = rtol , atol = atol ))
90+ conv1 = keras .ops .all (keras .ops .isclose (keras .ops .sum (plan , axis = 1 ), 1.0 / n , rtol = rtol , atol = atol ))
8591 return conv0 & conv1
8692
8793 def cond (_ , plan ):
@@ -90,8 +96,8 @@ def cond(_, plan):
9096
9197 def body (steps , plan ):
9298 # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
93- plan = keras .ops .softmax (plan , axis = 0 )
94- plan = keras .ops .softmax (plan , axis = 1 )
99+ plan = plan / keras .ops .sum (plan , axis = 0 , keepdims = True ) * ( 1.0 / m )
100+ plan = plan / keras .ops .sum (plan , axis = 1 , keepdims = True ) * ( 1.0 / n )
95101
96102 return steps + 1 , plan
97103
0 commit comments