Skip to content

Commit 363c5f9

Browse files
committed
doc string + example
1 parent 30bfc5c commit 363c5f9

File tree

2 files changed

+196
-18
lines changed

2 files changed

+196
-18
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
============================================
4+
OTDA unsupervised vs semi-supervised setting
5+
============================================
6+
7+
This example introduces semi-supervised domain adaptation in a 2D setting.
8+
It describes the problem of semi-supervised domain adaptation and introduces
9+
some optimal transport approaches to solve it.
10+
11+
Quantities such as optimal couplings, largest coupling coefficients and
12+
transported samples are represented in order to give a visual understanding
13+
of what the transport methods are doing.
14+
"""
15+
16+
# Authors: Remi Flamary <remi.flamary@unice.fr>
17+
# Stanislas Chambon <stan.chambon@gmail.com>
18+
#
19+
# License: MIT License
20+
21+
import matplotlib.pylab as pl
22+
import ot
23+
24+
25+
##############################################################################
26+
# generate data
27+
##############################################################################
28+
29+
n_samples_source = 150
30+
n_samples_target = 150
31+
32+
Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
33+
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
34+
35+
# Cost matrix
36+
M = ot.dist(Xs, Xt, metric='sqeuclidean')
37+
38+
39+
##############################################################################
40+
# Transport source samples onto target samples
41+
##############################################################################
42+
43+
# unsupervised domain adaptation
44+
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
45+
ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
46+
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
47+
48+
# semi-supervised domain adaptation
49+
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
50+
ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
51+
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
52+
53+
# semi-supervised DA uses available labeled target samples to modify the cost
54+
# matrix involved in the OT problem. The cost of transporting a source sample
55+
# of class A onto a target sample of class B != A is set to infinity, or a
56+
# very large value
57+
58+
59+
##############################################################################
60+
# Fig 1 : plots source and target samples + cost matrices
61+
##############################################################################
62+
63+
pl.figure(1, figsize=(10, 10))
64+
pl.subplot(2, 2, 1)
65+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
66+
pl.xticks([])
67+
pl.yticks([])
68+
pl.legend(loc=0)
69+
pl.title('Source samples')
70+
71+
pl.subplot(2, 2, 2)
72+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
73+
pl.xticks([])
74+
pl.yticks([])
75+
pl.legend(loc=0)
76+
pl.title('Target samples')
77+
78+
pl.subplot(2, 2, 3)
79+
pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
80+
pl.xticks([])
81+
pl.yticks([])
82+
pl.title('Cost matrix - unsupervised DA')
83+
84+
pl.subplot(2, 2, 4)
85+
pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
86+
pl.xticks([])
87+
pl.yticks([])
88+
pl.title('Cost matrix - semisupervised DA')
89+
90+
pl.tight_layout()
91+
92+
# the optimal coupling in the semi-supervised DA case will exhibit a shape
93+
# similar to that of the cost matrix (block diagonal matrix)
94+
95+
##############################################################################
96+
# Fig 2 : plots optimal couplings for the different methods
97+
##############################################################################
98+
99+
pl.figure(2, figsize=(8, 4))
100+
101+
pl.subplot(1, 2, 1)
102+
pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
103+
pl.xticks([])
104+
pl.yticks([])
105+
pl.title('Optimal coupling\nUnsupervised DA')
106+
107+
pl.subplot(1, 2, 2)
108+
pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
109+
pl.xticks([])
110+
pl.yticks([])
111+
pl.title('Optimal coupling\nSemi-supervised DA')
112+
113+
pl.tight_layout()
114+
115+
116+
##############################################################################
117+
# Fig 3 : plot transported samples
118+
##############################################################################
119+
120+
# display transported samples
121+
pl.figure(4, figsize=(8, 4))
122+
pl.subplot(1, 2, 1)
123+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
124+
label='Target samples', alpha=0.5)
125+
pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
126+
marker='+', label='Transp samples', s=30)
127+
pl.title('Transported samples\nEmdTransport')
128+
pl.legend(loc=0)
129+
pl.xticks([])
130+
pl.yticks([])
131+
132+
pl.subplot(1, 2, 2)
133+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
134+
label='Target samples', alpha=0.5)
135+
pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
136+
marker='+', label='Transp samples', s=30)
137+
pl.title('Transported samples\nSinkhornTransport')
138+
pl.xticks([])
139+
pl.yticks([])
140+
141+
pl.tight_layout()
142+
pl.show()

ot/da.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -966,8 +966,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
966966
The class labels
967967
Xt : array-like, shape (n_target_samples, n_features)
968968
The training input samples.
969-
yt : array-like, shape (n_labeled_target_samples,)
970-
The class labels
969+
yt : array-like, shape (n_target_samples,)
970+
The class labels. If some target samples are unlabeled, set the
971+
corresponding elements of yt to -1.
972+
973+
Warning: Note that, due to this convention -1 cannot be used as a
974+
class label
971975
972976
Returns
973977
-------
@@ -1023,8 +1027,12 @@ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10231027
The class labels
10241028
Xt : array-like, shape (n_target_samples, n_features)
10251029
The training input samples.
1026-
yt : array-like, shape (n_labeled_target_samples,)
1027-
The class labels
1030+
yt : array-like, shape (n_target_samples,)
1031+
The class labels. If some target samples are unlabeled, set the
1032+
corresponding elements of yt to -1.
1033+
1034+
Warning: Note that, due to this convention -1 cannot be used as a
1035+
class label
10281036
10291037
Returns
10301038
-------
@@ -1045,8 +1053,12 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10451053
The class labels
10461054
Xt : array-like, shape (n_target_samples, n_features)
10471055
The training input samples.
1048-
yt : array-like, shape (n_labeled_target_samples,)
1049-
The class labels
1056+
yt : array-like, shape (n_target_samples,)
1057+
The class labels. If some target samples are unlabeled, set the
1058+
corresponding elements of yt to -1.
1059+
1060+
Warning: Note that, due to this convention -1 cannot be used as a
1061+
class label
10501062
batch_size : int, optional (default=128)
10511063
The batch size for out of sample inverse transform
10521064
@@ -1110,8 +1122,12 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11101122
The class labels
11111123
Xt : array-like, shape (n_target_samples, n_features)
11121124
The training input samples.
1113-
yt : array-like, shape (n_labeled_target_samples,)
1114-
The class labels
1125+
yt : array-like, shape (n_target_samples,)
1126+
The class labels. If some target samples are unlabeled, set the
1127+
corresponding elements of yt to -1.
1128+
1129+
Warning: Note that, due to this convention -1 cannot be used as a
1130+
class label
11151131
batch_size : int, optional (default=128)
11161132
The batch size for out of sample inverse transform
11171133
@@ -1241,8 +1257,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12411257
The class labels
12421258
Xt : array-like, shape (n_target_samples, n_features)
12431259
The training input samples.
1244-
yt : array-like, shape (n_labeled_target_samples,)
1245-
The class labels
1260+
yt : array-like, shape (n_target_samples,)
1261+
The class labels. If some target samples are unlabeled, set the
1262+
corresponding elements of yt to -1.
1263+
1264+
Warning: Note that, due to this convention -1 cannot be used as a
1265+
class label
12461266
12471267
Returns
12481268
-------
@@ -1333,8 +1353,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13331353
The class labels
13341354
Xt : array-like, shape (n_target_samples, n_features)
13351355
The training input samples.
1336-
yt : array-like, shape (n_labeled_target_samples,)
1337-
The class labels
1356+
yt : array-like, shape (n_target_samples,)
1357+
The class labels. If some target samples are unlabeled, set the
1358+
corresponding elements of yt to -1.
1359+
1360+
Warning: Note that, due to this convention -1 cannot be used as a
1361+
class label
13381362
13391363
Returns
13401364
-------
@@ -1434,8 +1458,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14341458
The class labels
14351459
Xt : array-like, shape (n_target_samples, n_features)
14361460
The training input samples.
1437-
yt : array-like, shape (n_labeled_target_samples,)
1438-
The class labels
1461+
yt : array-like, shape (n_target_samples,)
1462+
The class labels. If some target samples are unlabeled, set the
1463+
corresponding elements of yt to -1.
1464+
1465+
Warning: Note that, due to this convention -1 cannot be used as a
1466+
class label
14391467
14401468
Returns
14411469
-------
@@ -1545,8 +1573,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15451573
The class labels
15461574
Xt : array-like, shape (n_target_samples, n_features)
15471575
The training input samples.
1548-
yt : array-like, shape (n_labeled_target_samples,)
1549-
The class labels
1576+
yt : array-like, shape (n_target_samples,)
1577+
The class labels. If some target samples are unlabeled, set the
1578+
corresponding elements of yt to -1.
1579+
1580+
Warning: Note that, due to this convention -1 cannot be used as a
1581+
class label
15501582
15511583
Returns
15521584
-------
@@ -1662,8 +1694,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
16621694
The class labels
16631695
Xt : array-like, shape (n_target_samples, n_features)
16641696
The training input samples.
1665-
yt : array-like, shape (n_labeled_target_samples,)
1666-
The class labels
1697+
yt : array-like, shape (n_target_samples,)
1698+
The class labels. If some target samples are unlabeled, set the
1699+
corresponding elements of yt to -1.
1700+
1701+
Warning: Note that, due to this convention -1 cannot be used as a
1702+
class label
16671703
16681704
Returns
16691705
-------

0 commit comments

Comments
 (0)