1010import pandas as pd
1111import json
1212import subprocess
13+ import glob
1314from cesium_app .model_util import create_token_user
15+ from baselayer .app .config import load_config
16+
17+
18+ cfg = load_config ()
1419
1520
1621def _add_prediction (proj_id , driver ):
@@ -154,41 +159,47 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
154159 model , prediction ):
155160 driver .get ('/' )
156161 _click_download (project .id , driver )
157- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
162+ matching_downloads_paths = glob .glob (f'{ cfg ["paths:downloads_folder" ]} /'
163+ 'cesium_prediction_results*.csv' )
164+ assert len (matching_downloads_paths ) == 1
158165 try :
159166 npt .assert_equal (
160- np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' ),
167+ np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' ),
161168 ['ts_name,label,prediction' ,
162169 '0,Mira,Mira' ,
163170 '1,Classical_Cepheid,Classical_Cepheid' ,
164171 '2,Mira,Mira' ,
165172 '3,Classical_Cepheid,Classical_Cepheid' ,
166173 '4,Mira,Mira' ])
167174 finally :
168- os .remove ('/tmp/cesium_prediction_results.csv' )
175+ os .remove (matching_downloads_paths [ 0 ] )
169176
170177
171178@pytest .mark .parametrize ('model__type' , ['LinearSGDClassifier' ])
172179def test_download_prediction_csv_class_unlabeled (driver , project , unlabeled_prediction ):
173180 driver .get ('/' )
174181 _click_download (project .id , driver )
175- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
182+ matching_downloads_paths = glob .glob (f'{ cfg ["paths:downloads_folder" ]} /'
183+ 'cesium_prediction_results*.csv' )
184+ assert len (matching_downloads_paths ) == 1
176185 try :
177- result = np .genfromtxt ('/tmp/cesium_prediction_results.csv' , dtype = 'str' )
186+ result = np .genfromtxt (matching_downloads_paths [ 0 ] , dtype = 'str' )
178187 assert result [0 ] == 'ts_name,prediction'
179188 assert all ([el [0 ].isdigit () and el [1 ] == ',' and el [2 :] in
180189 ['Mira' , 'Classical_Cepheid' ] for el in result [1 :]])
181190 finally :
182- os .remove ('/tmp/cesium_prediction_results.csv' )
191+ os .remove (matching_downloads_paths [ 0 ] )
183192
184193
185194def test_download_prediction_csv_class_prob (driver , project , dataset ,
186195 featureset , model , prediction ):
187196 driver .get ('/' )
188197 _click_download (project .id , driver )
189- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
198+ matching_downloads_paths = glob .glob (f'{ cfg ["paths:downloads_folder" ]} /'
199+ 'cesium_prediction_results*.csv' )
200+ assert len (matching_downloads_paths ) == 1
190201 try :
191- result = pd .read_csv ('/tmp/cesium_prediction_results.csv' )
202+ result = pd .read_csv (matching_downloads_paths [ 0 ] )
192203 npt .assert_array_equal (result .ts_name , np .arange (5 ))
193204 npt .assert_array_equal (result .label , ['Mira' , 'Classical_Cepheid' ,
194205 'Mira' , 'Classical_Cepheid' ,
@@ -198,16 +209,19 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
198209 [1 , 0 , 1 , 0 , 1 ])
199210 assert (pred_probs .values >= 0.0 ).all ()
200211 finally :
201- os .remove ('/tmp/cesium_prediction_results.csv' )
212+ os .remove (matching_downloads_paths [ 0 ] )
202213
203214
204215@pytest .mark .parametrize ('featureset__name, model__type' , [('regr' , 'LinearRegressor' )])
205- def test_download_prediction_csv_regr (driver , project , dataset , featureset , model , prediction ):
216+ def test_download_prediction_csv_regr (driver , project , dataset , featureset ,
217+ model , prediction ):
206218 driver .get ('/' )
207219 _click_download (project .id , driver )
208- assert os .path .exists ('/tmp/cesium_prediction_results.csv' )
220+ matching_downloads_paths = glob .glob (f'{ cfg ["paths:downloads_folder" ]} /'
221+ 'cesium_prediction_results*.csv' )
222+ assert len (matching_downloads_paths ) == 1
209223 try :
210- results = np .genfromtxt ('/tmp/cesium_prediction_results.csv' ,
224+ results = np .genfromtxt (matching_downloads_paths [ 0 ] ,
211225 dtype = 'str' , delimiter = ',' )
212226 npt .assert_equal (results [0 ],
213227 ['ts_name' , 'label' , 'prediction' ])
@@ -219,7 +233,7 @@ def test_download_prediction_csv_regr(driver, project, dataset, featureset, mode
219233 [3 , 2.2 , 2.2 ],
220234 [4 , 3.1 , 3.1 ]])
221235 finally :
222- os .remove ('/tmp/cesium_prediction_results.csv' )
236+ os .remove (matching_downloads_paths [ 0 ] )
223237
224238
225239def test_predict_specific_ts_name (driver , project , dataset , featureset , model ):
0 commit comments