@@ -36,8 +36,10 @@ def test_sinkhorn_lpl1_transport_class():
3636 # test margin constraints
3737 mu_s = unif (ns )
3838 mu_t = unif (nt )
39- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
40- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
39+ assert_allclose (
40+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
41+ assert_allclose (
42+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
4143
4244 # test transform
4345 transp_Xs = otda .transform (Xs = Xs )
@@ -108,8 +110,10 @@ def test_sinkhorn_l1l2_transport_class():
108110 # test margin constraints
109111 mu_s = unif (ns )
110112 mu_t = unif (nt )
111- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
112- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
113+ assert_allclose (
114+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
115+ assert_allclose (
116+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
113117
114118 # test transform
115119 transp_Xs = otda .transform (Xs = Xs )
@@ -187,8 +191,10 @@ def test_sinkhorn_transport_class():
187191 # test margin constraints
188192 mu_s = unif (ns )
189193 mu_t = unif (nt )
190- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
191- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
194+ assert_allclose (
195+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
196+ assert_allclose (
197+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
192198
193199 # test transform
194200 transp_Xs = otda .transform (Xs = Xs )
@@ -263,8 +269,10 @@ def test_emd_transport_class():
263269 # test margin constraints
264270 mu_s = unif (ns )
265271 mu_t = unif (nt )
266- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
267- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
272+ assert_allclose (
273+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
274+ assert_allclose (
275+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
268276
269277 # test transform
270278 transp_Xs = otda .transform (Xs = Xs )
@@ -342,8 +350,10 @@ def test_mapping_transport_class():
342350 # test margin constraints
343351 mu_s = unif (ns )
344352 mu_t = unif (nt )
345- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
346- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
353+ assert_allclose (
354+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
355+ assert_allclose (
356+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
347357
348358 # test transform
349359 transp_Xs = otda .transform (Xs = Xs )
@@ -363,8 +373,10 @@ def test_mapping_transport_class():
363373 # test margin constraints
364374 mu_s = unif (ns )
365375 mu_t = unif (nt )
366- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
367- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
376+ assert_allclose (
377+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
378+ assert_allclose (
379+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
368380
369381 # test transform
370382 transp_Xs = otda .transform (Xs = Xs )
@@ -389,8 +401,10 @@ def test_mapping_transport_class():
389401 # test margin constraints
390402 mu_s = unif (ns )
391403 mu_t = unif (nt )
392- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
393- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
404+ assert_allclose (
405+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
406+ assert_allclose (
407+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
394408
395409 # test transform
396410 transp_Xs = otda .transform (Xs = Xs )
@@ -410,8 +424,10 @@ def test_mapping_transport_class():
410424 # test margin constraints
411425 mu_s = unif (ns )
412426 mu_t = unif (nt )
413- assert_allclose (np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
414- assert_allclose (np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
427+ assert_allclose (
428+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
429+ assert_allclose (
430+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
415431
416432 # test transform
417433 transp_Xs = otda .transform (Xs = Xs )
@@ -454,7 +470,8 @@ def test_otda():
454470 da_entrop .interp ()
455471 da_entrop .predict (xs )
456472
457- np .testing .assert_allclose (a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
473+ np .testing .assert_allclose (
474+ a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
458475 np .testing .assert_allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
459476
460477 # non-convex Group lasso regularization
@@ -488,13 +505,3 @@ def test_otda():
488505 da_emd = ot .da .OTDA_mapping_kernel () # init class
489506 da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
490507 da_emd .predict (xs ) # interpolation of source samples
491-
492-
493- # if __name__ == "__main__":
494-
495- # test_sinkhorn_transport_class()
496- # test_emd_transport_class()
497- # test_sinkhorn_l1l2_transport_class()
498- # test_sinkhorn_lpl1_transport_class()
499- # test_mapping_transport_class()
500-
0 commit comments