Skip to content

Commit b12edc5

Browse files
committed
integrated test for semi supervised case
1 parent 669a6be commit b12edc5

File tree

1 file changed

+60
-36
lines changed

1 file changed

+60
-36
lines changed

test/test_da.py

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,25 @@ def test_sinkhorn_lpl1_transport_class():
6363
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
6464
assert_equal(transp_Xs.shape, Xs.shape)
6565

66-
# test semi supervised mode
67-
clf = ot.da.SinkhornLpl1Transport()
68-
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
69-
n_unsup = np.sum(clf.cost_)
66+
# test unsupervised vs semi-supervised mode
67+
clf_unsup = ot.da.SinkhornTransport()
68+
clf_unsup.fit(Xs=Xs, Xt=Xt)
69+
n_unsup = np.sum(clf_unsup.cost_)
7070

71-
# test semi supervised mode
72-
clf = ot.da.SinkhornLpl1Transport()
73-
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
74-
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
75-
n_semisup = np.sum(clf.cost_)
71+
clf_semi = ot.da.SinkhornTransport()
72+
clf_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
73+
assert_equal(clf_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
74+
n_semisup = np.sum(clf_semi.cost_)
7675

76+
# check that the cost matrix norms are indeed different
7777
assert n_unsup != n_semisup, "semisupervised mode not working"
7878

79+
# check that the coupling forbids mass transport between labeled source
80+
# and labeled target samples
81+
mass_semi = np.sum(
82+
clf_semi.coupling_[clf_semi.cost_ == clf_semi.limit_max])
83+
assert mass_semi == 0, "semisupervised mode not working"
84+
7985

8086
def test_sinkhorn_l1l2_transport_class():
8187
"""test_sinkhorn_transport
@@ -129,19 +135,25 @@ def test_sinkhorn_l1l2_transport_class():
129135
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
130136
assert_equal(transp_Xs.shape, Xs.shape)
131137

132-
# test semi supervised mode
133-
clf = ot.da.SinkhornL1l2Transport()
134-
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
135-
n_unsup = np.sum(clf.cost_)
138+
# test unsupervised vs semi-supervised mode
139+
clf_unsup = ot.da.SinkhornTransport()
140+
clf_unsup.fit(Xs=Xs, Xt=Xt)
141+
n_unsup = np.sum(clf_unsup.cost_)
136142

137-
# test semi supervised mode
138-
clf = ot.da.SinkhornL1l2Transport()
139-
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
140-
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
141-
n_semisup = np.sum(clf.cost_)
143+
clf_semi = ot.da.SinkhornTransport()
144+
clf_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
145+
assert_equal(clf_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
146+
n_semisup = np.sum(clf_semi.cost_)
142147

148+
# check that the cost matrix norms are indeed different
143149
assert n_unsup != n_semisup, "semisupervised mode not working"
144150

151+
# check that the coupling forbids mass transport between labeled source
152+
# and labeled target samples
153+
mass_semi = np.sum(
154+
clf_semi.coupling_[clf_semi.cost_ == clf_semi.limit_max])
155+
assert mass_semi == 0, "semisupervised mode not working"
156+
145157
# check everything runs well with log=True
146158
clf = ot.da.SinkhornL1l2Transport(log=True)
147159
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -200,19 +212,25 @@ def test_sinkhorn_transport_class():
200212
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
201213
assert_equal(transp_Xs.shape, Xs.shape)
202214

203-
# test semi supervised mode
204-
clf = ot.da.SinkhornTransport()
205-
clf.fit(Xs=Xs, Xt=Xt)
206-
n_unsup = np.sum(clf.cost_)
215+
# test unsupervised vs semi-supervised mode
216+
clf_unsup = ot.da.SinkhornTransport()
217+
clf_unsup.fit(Xs=Xs, Xt=Xt)
218+
n_unsup = np.sum(clf_unsup.cost_)
207219

208-
# test semi supervised mode
209-
clf = ot.da.SinkhornTransport()
210-
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
211-
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
212-
n_semisup = np.sum(clf.cost_)
220+
clf_semi = ot.da.SinkhornTransport()
221+
clf_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
222+
assert_equal(clf_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
223+
n_semisup = np.sum(clf_semi.cost_)
213224

225+
# check that the cost matrix norms are indeed different
214226
assert n_unsup != n_semisup, "semisupervised mode not working"
215227

228+
# check that the coupling forbids mass transport between labeled source
229+
# and labeled target samples
230+
mass_semi = np.sum(
231+
clf_semi.coupling_[clf_semi.cost_ == clf_semi.limit_max])
232+
assert mass_semi == 0, "semisupervised mode not working"
233+
216234
# check everything runs well with log=True
217235
clf = ot.da.SinkhornTransport(log=True)
218236
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -270,19 +288,25 @@ def test_emd_transport_class():
270288
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
271289
assert_equal(transp_Xs.shape, Xs.shape)
272290

273-
# test semi supervised mode
274-
clf = ot.da.EMDTransport()
275-
clf.fit(Xs=Xs, Xt=Xt)
276-
n_unsup = np.sum(clf.cost_)
291+
# test unsupervised vs semi-supervised mode
292+
clf_unsup = ot.da.SinkhornTransport()
293+
clf_unsup.fit(Xs=Xs, Xt=Xt)
294+
n_unsup = np.sum(clf_unsup.cost_)
277295

278-
# test semi supervised mode
279-
clf = ot.da.EMDTransport()
280-
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
281-
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
282-
n_semisup = np.sum(clf.cost_)
296+
clf_semi = ot.da.SinkhornTransport()
297+
clf_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
298+
assert_equal(clf_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
299+
n_semisup = np.sum(clf_semi.cost_)
283300

301+
# check that the cost matrix norms are indeed different
284302
assert n_unsup != n_semisup, "semisupervised mode not working"
285303

304+
# check that the coupling forbids mass transport between labeled source
305+
# and labeled target samples
306+
mass_semi = np.sum(
307+
clf_semi.coupling_[clf_semi.cost_ == clf_semi.limit_max])
308+
assert mass_semi == 0, "semisupervised mode not working"
309+
286310

287311
def test_mapping_transport_class():
288312
"""test_mapping_transport

0 commit comments

Comments
 (0)