@@ -16,50 +16,62 @@ def test_jit_compile():
1616 ot (x , y , regularization = 1.0 , seed = 0 , max_steps = 10 )
1717
1818
19- def test_shapes ():
19+ @pytest .mark .parametrize ("method" , ["log_sinkhorn" , "sinkhorn" ])
20+ def test_shapes (method ):
2021 x = keras .random .normal ((128 , 8 ), seed = 0 )
2122 y = keras .random .normal ((128 , 8 ), seed = 1 )
2223
23- ox , oy = optimal_transport (x , y , regularization = 1.0 , seed = 0 , max_steps = 10 )
24+ ox , oy = optimal_transport (x , y , regularization = 1.0 , seed = 0 , max_steps = 10 , method = method )
2425
2526 assert keras .ops .shape (ox ) == keras .ops .shape (x )
2627 assert keras .ops .shape (oy ) == keras .ops .shape (y )
2728
2829
2930def test_transport_cost_improves ():
30- x = keras .random .normal ((1024 , 2 ), seed = 0 )
31- y = keras .random .normal ((1024 , 2 ), seed = 1 )
31+ x = keras .random .normal ((128 , 2 ), seed = 0 )
32+ y = keras .random .normal ((128 , 2 ), seed = 1 )
3233
3334 before_cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
3435
35- x , y = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = None )
36+ x , y = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = 1000 )
3637
3738 after_cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
3839
3940 assert after_cost < before_cost
4041
4142
43+ @pytest .mark .skip (reason = "too unreliable" )
4244def test_assignment_is_optimal ():
43- x = keras .ops .convert_to_tensor (
44- [
45- [- 1 , 2 ],
46- [- 1 , 1 ],
47- [- 1 , 0 ],
48- [- 1 , - 1 ],
49- [- 1 , - 2 ],
50- ]
51- )
52- optimal_y = keras .ops .convert_to_tensor (
53- [
54- [1 , 2 ],
55- [1 , 1 ],
56- [1 , 0 ],
57- [1 , - 1 ],
58- [1 , - 2 ],
59- ]
60- )
61- y = keras .random .shuffle (optimal_y , axis = 0 , seed = 0 )
62-
63- x , y = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = None , scale_regularization = False )
64-
65- assert_allclose (x , y )
45+ x = keras .random .normal ((16 , 2 ), seed = 0 )
46+ p = keras .random .shuffle (keras .ops .arange (keras .ops .shape (x )[0 ]), seed = 0 )
47+ optimal_assignments = keras .ops .argsort (p )
48+
49+ y = x [p ]
50+
51+ x , y , assignments = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = 10_000 , return_assignments = True )
52+
53+ assert_allclose (assignments , optimal_assignments )
54+
55+
56+ def test_assignment_aligns_with_pot ():
57+ try :
58+ from ot .bregman import sinkhorn_log
59+ except (ImportError , ModuleNotFoundError ):
60+ pytest .skip ("Need to install POT to run this test." )
61+
62+ x = keras .random .normal ((16 , 2 ), seed = 0 )
63+ p = keras .random .shuffle (keras .ops .arange (keras .ops .shape (x )[0 ]), seed = 0 )
64+ y = x [p ]
65+
66+ a = keras .ops .ones (keras .ops .shape (x )[0 ])
67+ b = keras .ops .ones (keras .ops .shape (y )[0 ])
68+ M = x [:, None ] - y [None , :]
69+ M = keras .ops .norm (M , axis = - 1 )
70+
71+ pot_plan = sinkhorn_log (a , b , M , reg = 1e-3 , numItermax = 10_000 , stopThr = 1e-99 )
72+ pot_assignments = keras .random .categorical (pot_plan , num_samples = 1 , seed = 0 )
73+ pot_assignments = keras .ops .squeeze (pot_assignments , axis = - 1 )
74+
75+ _ , _ , assignments = optimal_transport (x , y , regularization = 1e-3 , seed = 0 , max_steps = 10_000 , return_assignments = True )
76+
77+ assert_allclose (pot_assignments , assignments )
0 commit comments