@@ -64,11 +64,11 @@ def test_sinkhorn_lpl1_transport_class():
6464 assert_equal (transp_Xs .shape , Xs .shape )
6565
6666 # test unsupervised vs semi-supervised mode
67- otda_unsup = ot .da .SinkhornTransport ()
68- otda_unsup .fit (Xs = Xs , Xt = Xt )
67+ otda_unsup = ot .da .SinkhornLpl1Transport ()
68+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
6969 n_unsup = np .sum (otda_unsup .cost_ )
7070
71- otda_semi = ot .da .SinkhornTransport ()
71+ otda_semi = ot .da .SinkhornLpl1Transport ()
7272 otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
7373 assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
7474 n_semisup = np .sum (otda_semi .cost_ )
@@ -136,11 +136,11 @@ def test_sinkhorn_l1l2_transport_class():
136136 assert_equal (transp_Xs .shape , Xs .shape )
137137
138138 # test unsupervised vs semi-supervised mode
139- otda_unsup = ot .da .SinkhornTransport ()
140- otda_unsup .fit (Xs = Xs , Xt = Xt )
139+ otda_unsup = ot .da .SinkhornL1l2Transport ()
140+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
141141 n_unsup = np .sum (otda_unsup .cost_ )
142142
143- otda_semi = ot .da .SinkhornTransport ()
143+ otda_semi = ot .da .SinkhornL1l2Transport ()
144144 otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
145145 assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
146146 n_semisup = np .sum (otda_semi .cost_ )
@@ -152,7 +152,9 @@ def test_sinkhorn_l1l2_transport_class():
152152 # and labeled target samples
153153 mass_semi = np .sum (
154154 otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ])
155- assert mass_semi == 0 , "semisupervised mode not working"
155+ mass_semi = otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ]
156+ assert_allclose (mass_semi , np .zeros_like (mass_semi ),
157+ rtol = 1e-9 , atol = 1e-9 )
156158
157159 # check everything runs well with log=True
158160 otda = ot .da .SinkhornL1l2Transport (log = True )
@@ -289,11 +291,11 @@ def test_emd_transport_class():
289291 assert_equal (transp_Xs .shape , Xs .shape )
290292
291293 # test unsupervised vs semi-supervised mode
292- otda_unsup = ot .da .SinkhornTransport ()
293- otda_unsup .fit (Xs = Xs , Xt = Xt )
294+ otda_unsup = ot .da .EMDTransport ()
295+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
294296 n_unsup = np .sum (otda_unsup .cost_ )
295297
296- otda_semi = ot .da .SinkhornTransport ()
298+ otda_semi = ot .da .EMDTransport ()
297299 otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
298300 assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
299301 n_semisup = np .sum (otda_semi .cost_ )
@@ -305,7 +307,11 @@ def test_emd_transport_class():
305307 # and labeled target samples
306308 mass_semi = np .sum (
307309 otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ])
308- assert mass_semi == 0 , "semisupervised mode not working"
310+ mass_semi = otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ]
311+
312+ # we need to use a small tolerance here, otherwise the test breaks
313+ assert_allclose (mass_semi , np .zeros_like (mass_semi ),
314+ rtol = 1e-2 , atol = 1e-2 )
309315
310316
311317def test_mapping_transport_class ():
@@ -491,3 +497,4 @@ def test_otda():
491497# test_sinkhorn_l1l2_transport_class()
492498# test_sinkhorn_lpl1_transport_class()
493499# test_mapping_transport_class()
500+
0 commit comments