@@ -151,7 +151,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
151151
152152 The algorithm used for solving the problem is the block coordinate
153153 descent that alternates between updates of G (using conditionnal gradient)
154- abd the update of L using a classical least square solver.
154+ and the update of L using a classical least square solver.
155155
156156
157157 Parameters
@@ -320,7 +320,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
320320
321321 The algorithm used for solving the problem is the block coordinate
322322 descent that alternates between updates of G (using conditionnal gradient)
323- abd the update of L using a classical kernel least square solver.
323+ and the update of L using a classical kernel least square solver.
324324
325325
326326 Parameters
@@ -492,7 +492,15 @@ def df(G):
492492
493493
494494class OTDA (object ):
495- """Class for domain adaptation with optimal transport"""
495+ """Class for domain adaptation with optimal transport as proposed in [5]
496+
497+
498+ References
499+ ----------
500+
501+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
502+
503+ """
496504
497505 def __init__ (self ,metric = 'sqeuclidean' ):
498506 """ Class initialization"""
@@ -504,8 +512,7 @@ def __init__(self,metric='sqeuclidean'):
504512
505513
506514 def fit (self ,xs ,xt ,ws = None ,wt = None ):
507- """ Fit domain adaptation between samples is xs and xt (with optional
508- weights)"""
515+ """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
509516 self .xs = xs
510517 self .xt = xt
511518
@@ -522,7 +529,7 @@ def fit(self,xs,xt,ws=None,wt=None):
522529 self .computed = True
523530
524531 def interp (self ,direction = 1 ):
525- """Barycentric interpolation for the source (1) or target (-1)
532+ """Barycentric interpolation for the source (1) or target (-1) samples
526533
527534 This Barycentric interpolation solves for each source (resp target)
528535 sample xs (resp xt) the following optimization problem:
@@ -558,10 +565,16 @@ def interp(self,direction=1):
558565
559566
560567 def predict (self ,x ,direction = 1 ):
561- """ Out of sample mapping using the formulation from Ferradans
568+ """ Out of sample mapping using the formulation from [6]
569+
570+ For each sample x to map, it finds the nearest source sample xs and
571+ map the samle x to the position xst+(x-xs) wher xst is the barycentric
572+ interpolation of source sample xs.
573+
574+ References
575+ ----------
562576
563- It basically find the source sample the nearset to the nex sample and
564- apply the difference to the displaced source sample.
577+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
565578
566579 """
567580 if direction > 0 : # >0 then source to target
@@ -582,8 +595,7 @@ class OTDA_sinkhorn(OTDA):
582595 """Class for domain adaptation with optimal transport with entropic regularization"""
583596
584597 def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
585- """ Fit domain adaptation between samples is xs and xt (with optional
586- weights)"""
598+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
587599 self .xs = xs
588600 self .xt = xt
589601
@@ -601,12 +613,12 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
601613
602614
603615class OTDA_lpl1 (OTDA ):
604- """Class for domain adaptation with optimal transport with entropic an group regularization"""
616+ """Class for domain adaptation with optimal transport with entropic and group regularization"""
605617
606618
607619 def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
608- """ Fit domain adaptation between samples is xs and xt (with optional
609- weights) """
620+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights),
621+ See ot.da.sinkhorn_lpl1_mm for fit parameters" "" "
610622 self .xs = xs
611623 self .xt = xt
612624
@@ -623,7 +635,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
623635 self .computed = True
624636
625637class OTDA_mapping_linear (OTDA ):
626- """Class for optimal transport with joint linear mapping estimation"""
638+ """Class for optimal transport with joint linear mapping estimation as in [8] """
627639
628640
629641 def __init__ (self ):
@@ -657,12 +669,7 @@ def mapping(self):
657669
658670
659671 def predict (self ,x ):
660- """ Out of sample mapping using the formulation from Ferradans
661-
662- It basically find the source sample the nearset to the nex sample and
663- apply the difference to the displaced source sample.
664-
665- """
672+ """ Out of sample mapping estimated during the call to fit"""
666673 if self .computed :
667674 if self .bias :
668675 x = np .hstack ((x ,np .ones ((x .shape [0 ],1 ))))
@@ -672,13 +679,12 @@ def predict(self,x):
672679 return None
673680
674681class OTDA_mapping_kernel (OTDA_mapping_linear ):
675- """Class for optimal transport with joint linear mapping estimation"""
682+ """Class for optimal transport with joint nonlinear mapping estimation as in [8] """
676683
677684
678685
679686 def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,kerneltype = 'gaussian' ,sigma = 1 ,** kwargs ):
680- """ Fit domain adaptation between samples is xs and xt (with optional
681- weights)"""
687+ """ Fit domain adaptation between samples is xs and xt """
682688 self .xs = xs
683689 self .xt = xt
684690 self .bias = bias
@@ -695,12 +701,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs)
695701
696702
697703 def predict (self ,x ):
698- """ Out of sample mapping using the formulation from Ferradans
699-
700- It basically find the source sample the nearset to the nex sample and
701- apply the difference to the displaced source sample.
702-
703- """
704+ """ Out of sample mapping estimated during the call to fit"""
704705
705706 if self .computed :
706707 K = kernel (x ,self .xs ,method = self .kernel ,sigma = self .sigma ,** self .kwargs )
0 commit comments