Skip to content

Commit 4f00d1a

Browse files
committed
Use downloads path from config in prediction tests
1 parent f53f6da commit 4f00d1a

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

cesium_app/tests/frontend/test_predict.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import subprocess
1313
import glob
1414
from cesium_app.model_util import create_token_user
15+
from baselayer.app.config import load_config
16+
17+
18+
cfg = load_config()
1519

1620

1721
def _add_prediction(proj_id, driver):
@@ -155,7 +159,8 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
155159
model, prediction):
156160
driver.get('/')
157161
_click_download(project.id, driver)
158-
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
162+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
163+
'cesium_prediction_results*.csv')
159164
assert len(matching_downloads_paths) == 1
160165
try:
161166
npt.assert_equal(
@@ -174,7 +179,8 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
174179
def test_download_prediction_csv_class_unlabeled(driver, project, unlabeled_prediction):
175180
driver.get('/')
176181
_click_download(project.id, driver)
177-
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
182+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
183+
'cesium_prediction_results*.csv')
178184
assert len(matching_downloads_paths) == 1
179185
try:
180186
result = np.genfromtxt(matching_downloads_paths[0], dtype='str')
@@ -189,7 +195,8 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
189195
featureset, model, prediction):
190196
driver.get('/')
191197
_click_download(project.id, driver)
192-
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
198+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
199+
'cesium_prediction_results*.csv')
193200
assert len(matching_downloads_paths) == 1
194201
try:
195202
result = pd.read_csv(matching_downloads_paths[0])
@@ -206,10 +213,12 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
206213

207214

208215
@pytest.mark.parametrize('featureset__name, model__type', [('regr', 'LinearRegressor')])
209-
def test_download_prediction_csv_regr(driver, project, dataset, featureset, model, prediction):
216+
def test_download_prediction_csv_regr(driver, project, dataset, featureset,
217+
model, prediction):
210218
driver.get('/')
211219
_click_download(project.id, driver)
212-
matching_downloads_paths = glob.glob('/tmp/cesium_prediction_results*.csv')
220+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
221+
'cesium_prediction_results*.csv')
213222
assert len(matching_downloads_paths) == 1
214223
try:
215224
results = np.genfromtxt(matching_downloads_paths[0],

0 commit comments

Comments
 (0)