77#
88# License: MIT License
99
10+ import multiprocessing
11+
1012import numpy as np
13+
1114# import compiled emd
12- from .emd_wrap import emd_c , emd2_c
15+ from .emd_wrap import emd_c , check_result
1316from ..utils import parmap
14- import multiprocessing
1517
1618
17- def emd (a , b , M , numItermax = 100000 ):
19+ def emd (a , b , M , numItermax = 100000 , log = False ):
1820 """Solves the Earth Movers distance problem and returns the OT matrix
1921
2022
@@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000):
4244 numItermax : int, optional (default=100000)
4345 The maximum number of iterations before stopping the optimization
4446 algorithm if it has not converged.
47+ log: boolean, optional (default=False)
48+ If True, returns a dictionary containing the cost and dual
49+ variables. Otherwise returns only the optimal transportation matrix.
4550
4651 Returns
4752 -------
4853 gamma: (ns x nt) ndarray
4954 Optimal transportation matrix for the given parameters
55+ log: dict
56+ If input log is true, a dictionary containing the cost and dual
57+ variables and exit status
5058
5159
5260 Examples
@@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000):
8290
8391 # if empty array given then use unifor distributions
8492 if len (a ) == 0 :
85- a = np .ones ((M .shape [0 ], ), dtype = np .float64 )/ M .shape [0 ]
93+ a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
8694 if len (b ) == 0 :
87- b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
88-
89- return emd_c (a , b , M , numItermax )
90-
91-
92- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 ):
95+ b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
96+
97+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
98+ result_code_string = check_result (result_code )
99+ if log :
100+ log = {}
101+ log ['cost' ] = cost
102+ log ['u' ] = u
103+ log ['v' ] = v
104+ log ['warning' ] = result_code_string
105+ log ['result_code' ] = result_code
106+ return G , log
107+ return G
108+
109+
110+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 , log = False , return_matrix = False ):
93111 """Solves the Earth Movers distance problem and returns the loss
94112
95113 .. math::
@@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
116134 numItermax : int, optional (default=100000)
117135 The maximum number of iterations before stopping the optimization
118136 algorithm if it has not converged.
137+ log: boolean, optional (default=False)
138+ If True, returns a dictionary containing the cost and dual
139+ variables. Otherwise returns only the optimal transportation cost.
140+ return_matrix: boolean, optional (default=False)
141+ If True, returns the optimal transportation matrix in the log.
119142
120143 Returns
121144 -------
122145 gamma: (ns x nt) ndarray
123146 Optimal transportation matrix for the given parameters
147+ log: dict
148+ If input log is true, a dictionary containing the cost and dual
149+ variables and exit status
124150
125151
126152 Examples
@@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
156182
157183 # if empty array given then use unifor distributions
158184 if len (a ) == 0 :
159- a = np .ones ((M .shape [0 ], ), dtype = np .float64 )/ M .shape [0 ]
185+ a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
160186 if len (b ) == 0 :
161- b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
187+ b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
162188
163- if len (b .shape ) == 1 :
164- return emd2_c (a , b , M , numItermax )
189+ if log or return_matrix :
190+ def f (b ):
191+ G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
192+ result_code_string = check_result (resultCode )
193+ log = {}
194+ if return_matrix :
195+ log ['G' ] = G
196+ log ['u' ] = u
197+ log ['v' ] = v
198+ log ['warning' ] = result_code_string
199+ log ['result_code' ] = resultCode
200+ return [cost , log ]
165201 else :
166- nb = b .shape [1 ]
167- # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
168-
169202 def f (b ):
170- return emd2_c (a , b , M , numItermax )
171- res = parmap (f , [b [:, i ] for i in range (nb )], processes )
172- return np .array (res )
203+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
204+ check_result (result_code )
205+ return cost
206+
207+ if len (b .shape ) == 1 :
208+ return f (b )
209+ nb = b .shape [1 ]
210+
211+ res = parmap (f , [b [:, i ] for i in range (nb )], processes )
212+ return res
0 commit comments