1717from inspect import signature
1818from .backend import get_backend , Backend , NumpyBackend , JaxBackend
1919
# Reference timestamp shared by tic()/toc()/toq(). It is initialised at import
# time so that toc()/toq() return a value even if tic() was never called.
20- __time_tic_toc = time .time ()
20+ __time_tic_toc = time .perf_counter ()
2121
2222
def tic():
    r"""Python implementation of Matlab tic() function.

    Records the current ``time.perf_counter()`` reading in the module-level
    ``__time_tic_toc`` variable; ``toc()`` and ``toq()`` report the time
    elapsed since that reading.
    """
    global __time_tic_toc
    __time_tic_toc = time.perf_counter()
2727
2828
def toc(message="Elapsed time : {} s"):
    r"""Python implementation of Matlab toc() function.

    Parameters
    ----------
    message : str, optional
        Format string with one ``{}`` placeholder that receives the elapsed
        time in seconds.

    Returns
    -------
    elapsed : float
        Difference between the current ``time.perf_counter()`` reading and
        the reference stored in ``__time_tic_toc``.
    """
    elapsed = time.perf_counter() - __time_tic_toc
    print(message.format(elapsed))
    return elapsed
3434
3535
def toq():
    r"""Python implementation of Julia toc() function.

    Quiet counterpart of ``toc()``: nothing is printed.

    Returns
    -------
    elapsed : float
        Difference between the current ``time.perf_counter()`` reading and
        the reference stored in ``__time_tic_toc``.
    """
    return time.perf_counter() - __time_tic_toc
4040
4141
@@ -251,7 +251,7 @@ def clean_zeros(a, b, M):
251251 return a2 , b2 , M2
252252
253253
254- def euclidean_distances (X , Y , squared = False ):
254+ def euclidean_distances (X , Y , squared = False , nx = None ):
255255 r"""
256256 Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
257257 distance matrix between each pair of vectors.
@@ -270,13 +270,13 @@ def euclidean_distances(X, Y, squared=False):
270270 -------
271271 distances : array-like, shape (`n_samples_1`, `n_samples_2`)
272272 """
273-
274- nx = get_backend (X , Y )
273+ if nx is None :
274+ nx = get_backend (X , Y )
275275
276276 a2 = nx .einsum ("ij,ij->i" , X , X )
277277 b2 = nx .einsum ("ij,ij->i" , Y , Y )
278278
279- c = - 2 * nx .dot (X , Y . T )
279+ c = - 2 * nx .dot (X , nx . transpose ( Y ) )
280280 c += a2 [:, None ]
281281 c += b2 [None , :]
282282
@@ -291,11 +291,21 @@ def euclidean_distances(X, Y, squared=False):
291291 return c
292292
293293
294- def dist (x1 , x2 = None , metric = "sqeuclidean" , p = 2 , w = None ):
294+ def dist (
295+ x1 ,
296+ x2 = None ,
297+ metric = "sqeuclidean" ,
298+ p = 2 ,
299+ w = None ,
300+ backend = "auto" ,
301+ nx = None ,
302+ use_tensor = False ,
303+ ):
295304 r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
296305
297306 .. note:: This function is backend-compatible and will work on arrays
298- from all compatible backends.
307+ from all compatible backends for the following metrics:
308+ 'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.
299309
300310 Parameters
301311 ----------
@@ -315,7 +325,17 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
315325 p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
316326 w : array-like, rank 1
317327 Weights for the weighted metrics.
318-
328+ backend : str, optional
329+ Backend to use for the computation. If 'auto', the backend is
330+ automatically selected based on the input data. if 'scipy',
331+ the ``scipy.spatial.distance.cdist`` function is used (and gradients are
332+ detached).
333+ use_tensor : bool, optional
334+ If true use tensorized computation for the distance matrix which can
335+ cause memory issues for large datasets. Default is False and the
336+ parameter is used only for the 'cityblock' and 'minkowski' metrics.
337+ nx : Backend, optional
338+ Backend to perform computations on. If omitted, the backend defaults to that of `x1`.
319339
320340 Returns
321341 -------
@@ -324,12 +344,69 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
324344 distance matrix computed with given metric
325345
326346 """
347+ if nx is None :
348+ nx = get_backend (x1 , x2 )
327349 if x2 is None :
328350 x2 = x1
329- if metric == "sqeuclidean" :
330- return euclidean_distances (x1 , x2 , squared = True )
351+ if backend == "scipy" : # force scipy backend with cdist function
352+ x1 = nx .to_numpy (x1 )
353+ x2 = nx .to_numpy (x2 )
354+ if isinstance (metric , str ) and metric .endswith ("minkowski" ):
355+ return nx .from_numpy (cdist (x1 , x2 , metric = metric , p = p , w = w ))
356+ if w is not None :
357+ return nx .from_numpy (cdist (x1 , x2 , metric = metric , w = w ))
358+ return nx .from_numpy (cdist (x1 , x2 , metric = metric ))
359+ elif metric == "sqeuclidean" :
360+ return euclidean_distances (x1 , x2 , squared = True , nx = nx )
331361 elif metric == "euclidean" :
332- return euclidean_distances (x1 , x2 , squared = False )
362+ return euclidean_distances (x1 , x2 , squared = False , nx = nx )
363+ elif metric == "cityblock" :
364+ if use_tensor :
365+ return nx .sum (nx .abs (x1 [:, None , :] - x2 [None , :, :]), axis = 2 )
366+ else :
367+ M = 0.0
368+ for i in range (x1 .shape [1 ]):
369+ M += nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :])
370+ return M
371+ elif metric == "minkowski" :
372+ if w is None :
373+ if use_tensor :
374+ return nx .power (
375+ nx .sum (
376+ nx .power (nx .abs (x1 [:, None , :] - x2 [None , :, :]), p ), axis = 2
377+ ),
378+ 1 / p ,
379+ )
380+ else :
381+ M = 0.0
382+ for i in range (x1 .shape [1 ]):
383+ M += nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :]) ** p
384+ return M ** (1 / p )
385+ else :
386+ if use_tensor :
387+ return nx .power (
388+ nx .sum (
389+ w [None , None , :]
390+ * nx .power (nx .abs (x1 [:, None , :] - x2 [None , :, :]), p ),
391+ axis = 2 ,
392+ ),
393+ 1 / p ,
394+ )
395+ else :
396+ M = 0.0
397+ for i in range (x1 .shape [1 ]):
398+ M += w [i ] * nx .abs (x1 [:, i ][:, None ] - x2 [:, i ][None , :]) ** p
399+ return M ** (1 / p )
400+ elif metric == "cosine" :
401+ nx1 = nx .sqrt (nx .einsum ("ij,ij->i" , x1 , x1 ))
402+ nx2 = nx .sqrt (nx .einsum ("ij,ij->i" , x2 , x2 ))
403+ return 1.0 - (nx .dot (x1 , nx .transpose (x2 )) / nx1 [:, None ] / nx2 [None , :])
404+ elif metric == "correlation" :
405+ x1 = x1 - nx .mean (x1 , axis = 1 )[:, None ]
406+ x2 = x2 - nx .mean (x2 , axis = 1 )[:, None ]
407+ nx1 = nx .sqrt (nx .einsum ("ij,ij->i" , x1 , x1 ))
408+ nx2 = nx .sqrt (nx .einsum ("ij,ij->i" , x2 , x2 ))
409+ return 1.0 - (nx .dot (x1 , nx .transpose (x2 )) / nx1 [:, None ] / nx2 [None , :])
333410 else :
334411 if not get_backend (x1 , x2 ).__name__ == "numpy" :
335412 raise NotImplementedError ()
0 commit comments