1010import pandas as pd
1111import json
1212import subprocess
13+ import glob
1314from cesium_app .model_util import create_token_user
1415
1516
@@ -154,41 +155,44 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
154155 model , prediction ):
155156 driver .get ('/' )
156157 _click_download (project .id , driver )
157- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
158+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
159+ assert len (matching_downloads_paths ) == 1
158160 try :
159161 npt .assert_equal (
160- np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' ),
162+ np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' ),
161163 ['ts_name,label,prediction' ,
162164 '0,Mira,Mira' ,
163165 '1,Classical_Cepheid,Classical_Cepheid' ,
164166 '2,Mira,Mira' ,
165167 '3,Classical_Cepheid,Classical_Cepheid' ,
166168 '4,Mira,Mira' ])
167169 finally :
168- os .remove ('/tmp/cesium_prediction_results.csv' )
170+ os .remove (matching_downloads_paths [ 0 ] )
169171
170172
171173@pytest .mark .parametrize ('model__type' , ['LinearSGDClassifier' ])
172174def test_download_prediction_csv_class_unlabeled (driver , project , unlabeled_prediction ):
173175 driver .get ('/' )
174176 _click_download (project .id , driver )
175- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
177+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
178+ assert len (matching_downloads_paths ) == 1
176179 try :
177- result = np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' )
180+ result = np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' )
178181 assert result [0 ] == 'ts_name,prediction'
179182 assert all ([el [0 ].isdigit () and el [1 ] == ',' and el [2 :] in
180183 ['Mira' , 'Classical_Cepheid' ] for el in result [1 :]])
181184 finally :
182- os .remove ('/tmp/cesium_prediction_results.csv' )
185+ os .remove (matching_downloads_paths [ 0 ] )
183186
184187
185188def test_download_prediction_csv_class_prob (driver , project , dataset ,
186189 featureset , model , prediction ):
187190 driver .get ('/' )
188191 _click_download (project .id , driver )
189- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
192+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
193+ assert len (matching_downloads_paths ) == 1
190194 try :
191- result = pd .read_csv ('/tmp/cesium_prediction_results.csv' )
195+ result = pd .read_csv (matching_downloads_paths [ 0 ] )
192196 npt .assert_array_equal (result .ts_name , np .arange (5 ))
193197 npt .assert_array_equal (result .label , ['Mira' , 'Classical_Cepheid' ,
194198 'Mira' , 'Classical_Cepheid' ,
@@ -198,16 +202,17 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
198202 [1 , 0 , 1 , 0 , 1 ])
199203 assert (pred_probs .values >= 0.0 ).all ()
200204 finally :
201- os .remove ('/tmp/cesium_prediction_results.csv' )
205+ os .remove (matching_downloads_paths [ 0 ] )
202206
203207
204208@pytest .mark .parametrize ('featureset__name, model__type' , [('regr' , 'LinearRegressor' )])
205209def test_download_prediction_csv_regr (driver , project , dataset , featureset , model , prediction ):
206210 driver .get ('/' )
207211 _click_download (project .id , driver )
208- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
212+ matching_downloads_paths = glob .glob ('/tmp/cesium_prediction_results*.csv' )
213+ assert len (matching_downloads_paths ) == 1
209214 try :
210- results = np .genfromtxt ('/tmp/cesium_prediction_results.csv' ,
215+ results = np .genfromtxt (matching_downloads_paths [ 0 ] ,
211216 dtype = 'str' , delimiter = ',' )
212217 npt .assert_equal (results [0 ],
213218 ['ts_name' , 'label' , 'prediction' ])
@@ -219,7 +224,7 @@ def test_download_prediction_csv_regr(driver, project, dataset, featureset, mode
219224 [3 , 2.2 , 2.2 ],
220225 [4 , 3.1 , 3.1 ]])
221226 finally :
222- os .remove ('/tmp/cesium_prediction_results.csv' )
227+ os .remove (matching_downloads_paths [ 0 ] )
223228
224229
225230def test_predict_specific_ts_name (driver , project , dataset , featureset , model ):
0 commit comments