66
77import numpy as np
88import ot
9- import time
109import pytest
1110
1211try : # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
3130
3231 np .testing .assert_allclose (M , M2 , rtol = 1e-10 )
3332
34- M2 = ot .gpu .dist (a .copy (), b .copy (), to_numpy = False )
33+ M2 = ot .gpu .dist (a .copy (), b .copy (), metric = 'euclidean' , to_numpy = False )
34+
35+ # check raise not implemented wrong metric
36+ with pytest .raises (NotImplementedError ):
37+ M2 = ot .gpu .dist (a .copy (), b .copy (), metric = 'cityblock' , to_numpy = False )
3538
3639
3740@pytest .mark .skipif (nogpu , reason = "No GPU available" )
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
4649 wa = ot .unif (n_samples // 4 )
4750 wb = ot .unif (n_samples )
4851
52+ wb2 = np .random .rand (n_samples , 20 )
53+ wb2 /= wb2 .sum (0 , keepdims = True )
54+
4955 M = ot .dist (a .copy (), b .copy ())
5056 M2 = ot .gpu .dist (a .copy (), b .copy (), to_numpy = False )
5157
@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
5662
5763 np .testing .assert_allclose (G1 , G , rtol = 1e-10 )
5864
59- G2 = ot .gpu .sinkhorn (wa , wb , M2 , reg , to_numpy = False )
65+ # run all on gpu
66+ ot .gpu .sinkhorn (wa , wb , M2 , reg , to_numpy = False , log = True )
67+
68+ # run sinkhorn for multiple targets
69+ ot .gpu .sinkhorn (wa , wb2 , M2 , reg , to_numpy = False , log = True )
6070
6171
6272@pytest .mark .skipif (nogpu , reason = "No GPU available" )
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
8393
8494 np .testing .assert_allclose (G1 , G , rtol = 1e-10 )
8595
86- G2 = ot .gpu .da .sinkhorn_lpl1_mm (wa , labels_a , wb , M2 , reg , to_numpy = False )
96+ ot .gpu .da .sinkhorn_lpl1_mm (wa , labels_a , wb , M2 , reg , to_numpy = False , log = True )
0 commit comments