 #
 # License: MIT License

-import numpy as np
-import cudamat
+import cupy as np  # np used for matrix computation
+import cupy as cp  # cp used for cupy specific operations
+from . import utils


-def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
-             log=False, returnAsGPU=False):
-    r"""
-    Solve the entropic regularization optimal transport problem on GPU
+
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+                   verbose=False, log=False, to_numpy=True, **kwargs):
+    r"""
+    Solve the entropic regularization optimal transport problem and return the OT matrix

     The function solves the following optimization problem:

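For reference, the problem the docstring refers to is the standard entropic OT
formulation used throughout POT (restated here, not copied from the elided hunk):

    \gamma^\star = \mathop{\arg\min}_{\gamma \ge 0} \; \langle \gamma, M \rangle_F
                   + \mathrm{reg} \cdot \Omega(\gamma)
    \quad \text{s.t.} \quad \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b

where \Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log \gamma_{i,j} is the entropic
regularizer and a, b are the source and target sample weights.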
@@ -40,9 +42,10 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
     ----------
     a : np.ndarray (ns,)
         sample weights in the source domain
-    b : np.ndarray (nt,)
-        sample weights in the target domain
-    M_GPU : cudamat.CUDAMatrix (ns,nt)
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        sample weights in the target domain; if b is a matrix, compute
+        sinkhorn with multiple targets and a fixed M (the OT losses are
+        returned, with the dual variables in log)
+    M : np.ndarray (ns,nt)
         loss matrix
     reg : float
         Regularization term >0
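A quick usage sketch of the multi-target mode documented above (the `ot.gpu`
import path and all arrays are illustrative assumptions, not from this diff):

    import numpy as np
    import ot.gpu

    a = np.ones(50) / 50                   # source weights, shape (ns,)
    B = np.random.rand(60, 4)              # 4 target histograms, shape (nt, nbb)
    B /= B.sum(axis=0, keepdims=True)      # normalize each column to a distribution
    M = np.random.rand(50, 60)             # shared loss matrix, shape (ns, nt)

    losses = ot.gpu.sinkhorn(a, B, M, reg=0.1)      # shape (4,): one OT loss per target
    G = ot.gpu.sinkhorn(a, B[:, 0], M, reg=0.1)     # single target: (ns, nt) OT matrix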
@@ -54,8 +57,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
         Print information along iterations
     log : bool, optional
         record log if True
-    returnAsGPU : bool, optional
-        return the OT matrix as a cudamat.CUDAMatrix
+    to_numpy : bool, optional
+        return the result as a numpy array if True, else keep it on GPU as
+        a cupy array

     Returns
     -------
@@ -88,60 +90,78 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
     ot.optim.cg : General regularized OT

     """
+
+    a = cp.asarray(a, dtype=np.float64)
+    b = cp.asarray(b, dtype=np.float64)
+    M = cp.asarray(M, dtype=np.float64)
+
+    if len(a) == 0:
+        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+    if len(b) == 0:
+        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
     # init data
     Nini = len(a)
     Nfin = len(b)

+    if len(b.shape) > 1:
+        nbb = b.shape[1]
+    else:
+        nbb = 0
+
     if log:
         log = {'err': []}

     # we assume that no distances are null except those of the diagonal of
     # distances
-    u = (np.ones(Nini) / Nini).reshape((Nini, 1))
-    u_GPU = cudamat.CUDAMatrix(u)
-    a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
-    ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
-    v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
-    v_GPU = cudamat.CUDAMatrix(v)
-    b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
-
-    M_GPU.divide(-reg)
+    if nbb:
+        u = np.ones((Nini, nbb)) / Nini
+        v = np.ones((Nfin, nbb)) / Nfin
+    else:
+        u = np.ones(Nini) / Nini
+        v = np.ones(Nfin) / Nfin

-    K_GPU = cudamat.exp(M_GPU)
+    # print(reg)

-    ones_GPU.divide(a_GPU, target=a_GPU)
-    Kp_GPU = cudamat.empty(K_GPU.shape)
-    K_GPU.mult_by_col(a_GPU, target=Kp_GPU)
+    # Next 3 lines equivalent to K = np.exp(-M / reg), but faster to compute
+    K = np.empty(M.shape, dtype=M.dtype)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)

-    tmp_GPU = cudamat.empty(K_GPU.shape)
+    # print(np.min(K))
+    tmp2 = np.empty(b.shape, dtype=M.dtype)

+    Kp = (1 / a).reshape(-1, 1) * K
     cpt = 0
     err = 1
     while (err > stopThr and cpt < numItermax):
-        uprev_GPU = u_GPU.copy()
-        vprev_GPU = v_GPU.copy()
+        uprev = u
+        vprev = v

-        KtransposeU_GPU = K_GPU.transpose().dot(u_GPU)
-        b_GPU.divide(KtransposeU_GPU, target=v_GPU)
-        ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
+        KtransposeU = np.dot(K.T, u)
+        v = np.divide(b, KtransposeU)
+        u = 1. / np.dot(Kp, v)

-        if (np.any(KtransposeU_GPU.asarray() == 0) or
-                not u_GPU.allfinite() or not v_GPU.allfinite()):
+        if (np.any(KtransposeU == 0) or
+                np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+                np.any(np.isinf(u)) or np.any(np.isinf(v))):
             # we have reached the machine precision
             # come back to previous solution and quit loop
             print('Warning: numerical errors at iteration', cpt)
-            u_GPU = uprev_GPU.copy()
-            v_GPU = vprev_GPU.copy()
+            u = uprev
+            v = vprev
             break
         if cpt % 10 == 0:
             # we can speed up the process by checking for the error only
             # every 10th iteration
-            K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
-            tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)
-
-            bcopy_GPU = b_GPU.copy().transpose()
-            bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
-            err = bcopy_GPU.euclid_norm() ** 2
+            if nbb:
+                err = np.sum((u - uprev) ** 2) / np.sum(u ** 2) + \
+                    np.sum((v - vprev) ** 2) / np.sum(v ** 2)
+            else:
+                # compute the right marginal tmp2 = (diag(u) K diag(v))^T 1
+                tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
+                # tmp2 = np.einsum('i,ij,j->j', u, K, v)
+                err = np.linalg.norm(tmp2 - b) ** 2  # violation of marginal
             if log:
                 log['err'].append(err)

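A small numpy sanity sketch for the marginal check above (shapes and data are
illustrative assumptions): the broadcast sum, the commented-out einsum, and a
matrix-vector form all compute the same right marginal of diag(u) K diag(v).

    import numpy as np

    rng = np.random.default_rng(0)
    u, v, K = rng.random(5), rng.random(7), rng.random((5, 7))

    tmp2 = np.sum(u[:, None] * K * v[None, :], 0)      # form used in the diff
    assert np.allclose(tmp2, np.einsum('i,ij,j->j', u, K, v))
    assert np.allclose(tmp2, v * (K.T @ u))            # equivalent matvec form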
@@ -150,20 +170,31 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
                 print(
                     '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
             print('{:5d}|{:8e}|'.format(cpt, err))
-        cpt += 1
-    if log:
-        log['u'] = u_GPU.asarray()
-        log['v'] = v_GPU.asarray()
-
-    K_GPU.mult_by_col(u_GPU, target=K_GPU)
-    K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
-
-    if returnAsGPU:
-        res = K_GPU
-    else:
-        res = K_GPU.asarray()
-
+        cpt = cpt + 1
     if log:
-        return res, log
-    else:
-        return res
+        log['u'] = u
+        log['v'] = v
+
+    if nbb:  # return only loss
+        # res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)  (explodes cupy memory)
+        res = np.empty(nbb)
+        for i in range(nbb):
+            res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
+        if to_numpy:
+            res = utils.to_np(res)
+        if log:
+            return res, log
+        else:
+            return res
+
+    else:  # return OT matrix
+        res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
+        if to_numpy:
+            res = utils.to_np(res)
+        if log:
+            return res, log
+        else:
+            return res
+
+
+# define sinkhorn as sinkhorn_knopp
+sinkhorn = sinkhorn_knopp
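An end-to-end usage sketch (assumes cupy is installed and that this module is
exposed as ot.gpu, as the relative `utils` import suggests; the data is made up):

    import numpy as np
    import ot.gpu

    a = np.ones(100) / 100
    b = np.ones(100) / 100
    M = np.random.rand(100, 100)

    G_host = ot.gpu.sinkhorn(a, b, M, reg=1e-1)                   # numpy array on host
    G_dev = ot.gpu.sinkhorn(a, b, M, reg=1e-1, to_numpy=False)    # cupy array, stays on GPU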