11# -*- coding: utf-8 -*-
22"""
3- Utility functions for GPU
3+ Utility functions for GPU
44"""
55
66# Author: Remi Flamary <remi.flamary@unice.fr>
99#
1010# License: MIT License
1111
12- import cupy as np # np used for matrix computation
13- import cupy as cp # cp used for cupy specific operations
14-
12+ import cupy as np # np used for matrix computation
13+ import cupy as cp # cp used for cupy specific operations
1514
1615
1716def euclidean_distances (a , b , squared = False , to_numpy = True ):
@@ -34,23 +33,24 @@ def euclidean_distances(a, b, squared=False, to_numpy=True):
3433 c : (n x m) np.ndarray or cupy.ndarray
3534 pairwise euclidean distance distance matrix
3635 """
37-
36+
3837 a , b = to_gpu (a , b )
39-
40- a2 = np .sum (np .square (a ),1 )
41- b2 = np .sum (np .square (b ),1 )
42-
43- c = - 2 * np .dot (a ,b .T )
44- c += a2 [:,None ]
45- c += b2 [None ,:]
46-
38+
39+ a2 = np .sum (np .square (a ), 1 )
40+ b2 = np .sum (np .square (b ), 1 )
41+
42+ c = - 2 * np .dot (a , b .T )
43+ c += a2 [:, None ]
44+ c += b2 [None , :]
45+
4746 if not squared :
4847 np .sqrt (c , out = c )
4948 if to_numpy :
5049 return to_np (c )
5150 else :
5251 return c
5352
53+
5454def dist (x1 , x2 = None , metric = 'sqeuclidean' , to_numpy = True ):
5555 """Compute distance between samples in x1 and x2 on gpu
5656
@@ -61,8 +61,8 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
6161 matrix with n1 samples of size d
6262 x2 : np.array (n2,d), optional
6363 matrix with n2 samples of size d (if None then x2=x1)
64- metric : str
65- Metric from 'sqeuclidean', 'euclidean',
64+ metric : str
65+ Metric from 'sqeuclidean', 'euclidean',
6666
6767
6868 Returns
@@ -80,7 +80,6 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
8080 return euclidean_distances (x1 , x2 , squared = False , to_numpy = to_numpy )
8181 else :
8282 raise NotImplementedError
83-
8483
8584
8685def to_gpu (* args ):
@@ -91,10 +90,9 @@ def to_gpu(*args):
9190 return cp .asarray (args [0 ])
9291
9392
94-
9593def to_np (* args ):
9694 """ convert GPU arras to numpy and return them"""
9795 if len (args ) > 1 :
9896 return (cp .asnumpy (x ) for x in args )
9997 else :
100- return cp .asnumpy (args [0 ])
98+ return cp .asnumpy (args [0 ])
0 commit comments