Skip to content

Commit bfca4eb

Browse files
authored
Merge pull request #235 from acrellin/download_features_models
Allow computed features and models to be downloaded
2 parents 86c4f95 + 4f00d1a commit bfca4eb

File tree

10 files changed

+138
-42
lines changed

10 files changed

+138
-42
lines changed

cesium_app/app_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@ def make_app(cfg, baselayer_handlers, baselayer_settings):
5656
handlers = baselayer_handlers + [
5757
(r'/project(/.*)?', ProjectHandler),
5858
(r'/dataset(/.*)?', DatasetHandler),
59-
(r'/features(/.*)?', FeatureHandler),
60-
(r'/models(/.*)?', ModelHandler),
59+
(r'/features(/[0-9]+)?', FeatureHandler),
60+
(r'/features/([0-9]+)/(download)', FeatureHandler),
61+
(r'/models(/[0-9]+)?', ModelHandler),
62+
(r'/models/([0-9]+)/(download)', ModelHandler),
6163
(r'/predictions(/[0-9]+)?', PredictionHandler),
6264
(r'/predictions/([0-9]+)/(download)', PredictionHandler),
6365
(r'/predict_raw_data', PredictRawDataHandler),

cesium_app/handlers/feature.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,35 @@
1313
from os.path import join as pjoin
1414
import uuid
1515
import datetime
16+
import pandas as pd
1617

1718

1819
class FeatureHandler(BaseHandler):
1920
@auth_or_token
20-
def get(self, featureset_id=None):
21-
if featureset_id is not None:
22-
featureset_info = Featureset.get_if_owned_by(featureset_id,
23-
self.current_user)
21+
def get(self, featureset_id=None, action=None):
22+
if action == 'download':
23+
featureset = Featureset.get_if_owned_by(featureset_id,
24+
self.current_user)
25+
fset_path = featureset.file_uri
26+
fset, data = featurize.load_featureset(fset_path)
27+
if 'labels' in data:
28+
fset['labels'] = data['labels']
29+
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
30+
self.set_header(
31+
"Content-Disposition", "attachment; "
32+
f"filename=cesium_featureset_{featureset.project.name}"
33+
f"_{featureset.name}_{featureset.finished}.csv")
34+
self.write(fset.to_csv(index=True))
2435
else:
25-
featureset_info = [f for p in self.current_user.projects
26-
for f in p.featuresets]
27-
featureset_info.sort(key=lambda f: f.created_at, reverse=True)
28-
29-
self.success(featureset_info)
36+
if featureset_id is not None:
37+
featureset_info = Featureset.get_if_owned_by(featureset_id,
38+
self.current_user)
39+
else:
40+
featureset_info = [f for p in self.current_user.projects
41+
for f in p.featuresets]
42+
featureset_info.sort(key=lambda f: f.created_at, reverse=True)
43+
44+
self.success(featureset_info)
3045

3146
@auth_or_token
3247
async def _await_featurization(self, future, fset):

cesium_app/handlers/model.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import uuid
1414
import datetime
1515

16+
import sklearn
1617
from sklearn.model_selection import GridSearchCV
1718
import joblib
1819

@@ -72,14 +73,28 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
7273

7374
class ModelHandler(BaseHandler):
7475
@auth_or_token
75-
def get(self, model_id=None):
76-
if model_id is not None:
77-
model_info = Model.get_if_owned_by(model_id, self.current_user)
76+
def get(self, model_id=None, action=None):
77+
if action == 'download':
78+
model = Model.get_if_owned_by(model_id, self.current_user)
79+
model_path = model.file_uri
80+
with open(model_path, 'rb') as f:
81+
model_data = f.read()
82+
self.set_header("Content-Type", "application/octet-stream")
83+
self.set_header(
84+
"Content-Disposition", "attachment; "
85+
f"filename=cesium_model_{model.project.name}"
86+
f"_{model.name}_{str(model.finished).replace(' ', 'T')}"
87+
f"_joblib_v{joblib.__version__}"
88+
f"_sklearn_v{sklearn.__version__}.pkl")
89+
self.write(model_data)
7890
else:
79-
model_info = [model for p in self.current_user.projects
80-
for model in p.models]
91+
if model_id is not None:
92+
model_info = Model.get_if_owned_by(model_id, self.current_user)
93+
else:
94+
model_info = [model for p in self.current_user.projects
95+
for model in p.models]
8196

82-
return self.success(model_info)
97+
return self.success(model_info)
8398

8499
@auth_or_token
85100
async def _await_model_statistics(self, model_stats_future, model):

cesium_app/handlers/prediction.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ async def post(self):
121121
@auth_or_token
122122
def get(self, prediction_id=None, action=None):
123123
if action == 'download':
124-
pred_path = Prediction.get_if_owned_by(prediction_id,
125-
self.current_user).file_uri
124+
prediction = Prediction.get_if_owned_by(prediction_id, self.current_user)
125+
pred_path = prediction.file_uri
126126
fset, data = featurize.load_featureset(pred_path)
127127
result = pd.DataFrame(({'label': data['labels']}
128128
if len(data['labels']) > 0 else None),
@@ -133,8 +133,11 @@ def get(self, prediction_id=None, action=None):
133133
result['prediction'] = data['preds']
134134
result.index.name = 'ts_name'
135135
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
136-
self.set_header("Content-Disposition", "attachment; "
137-
"filename=cesium_prediction_results.csv")
136+
self.set_header(
137+
"Content-Disposition", "attachment; "
138+
f"filename=cesium_prediction_results_{prediction.project.name}"
139+
f"_{prediction.dataset.name}"
140+
f"_{prediction.model.name}_{prediction.finished}.csv")
138141
self.write(result.to_csv(index=True))
139142
else:
140143
if prediction_id is None:

cesium_app/tests/frontend/test_predict.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
import pandas as pd
1111
import json
1212
import subprocess
13+
import glob
1314
from cesium_app.model_util import create_token_user
15+
from baselayer.app.config import load_config
16+
17+
18+
cfg = load_config()
1419

1520

1621
def _add_prediction(proj_id, driver):
@@ -154,41 +159,47 @@ def test_download_prediction_csv_class(driver, project, dataset, featureset,
154159
model, prediction):
155160
driver.get('/')
156161
_click_download(project.id, driver)
157-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
162+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
163+
'cesium_prediction_results*.csv')
164+
assert len(matching_downloads_paths) == 1
158165
try:
159166
npt.assert_equal(
160-
np.genfromtxt('/tmp/cesium_prediction_results.csv', dtype='str'),
167+
np.genfromtxt(matching_downloads_paths[0], dtype='str'),
161168
['ts_name,label,prediction',
162169
'0,Mira,Mira',
163170
'1,Classical_Cepheid,Classical_Cepheid',
164171
'2,Mira,Mira',
165172
'3,Classical_Cepheid,Classical_Cepheid',
166173
'4,Mira,Mira'])
167174
finally:
168-
os.remove('/tmp/cesium_prediction_results.csv')
175+
os.remove(matching_downloads_paths[0])
169176

170177

171178
@pytest.mark.parametrize('model__type', ['LinearSGDClassifier'])
172179
def test_download_prediction_csv_class_unlabeled(driver, project, unlabeled_prediction):
173180
driver.get('/')
174181
_click_download(project.id, driver)
175-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
182+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
183+
'cesium_prediction_results*.csv')
184+
assert len(matching_downloads_paths) == 1
176185
try:
177-
result = np.genfromtxt('/tmp/cesium_prediction_results.csv', dtype='str')
186+
result = np.genfromtxt(matching_downloads_paths[0], dtype='str')
178187
assert result[0] == 'ts_name,prediction'
179188
assert all([el[0].isdigit() and el[1] == ',' and el[2:] in
180189
['Mira', 'Classical_Cepheid'] for el in result[1:]])
181190
finally:
182-
os.remove('/tmp/cesium_prediction_results.csv')
191+
os.remove(matching_downloads_paths[0])
183192

184193

185194
def test_download_prediction_csv_class_prob(driver, project, dataset,
186195
featureset, model, prediction):
187196
driver.get('/')
188197
_click_download(project.id, driver)
189-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
198+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
199+
'cesium_prediction_results*.csv')
200+
assert len(matching_downloads_paths) == 1
190201
try:
191-
result = pd.read_csv('/tmp/cesium_prediction_results.csv')
202+
result = pd.read_csv(matching_downloads_paths[0])
192203
npt.assert_array_equal(result.ts_name, np.arange(5))
193204
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
194205
'Mira', 'Classical_Cepheid',
@@ -198,16 +209,19 @@ def test_download_prediction_csv_class_prob(driver, project, dataset,
198209
[1, 0, 1, 0, 1])
199210
assert (pred_probs.values >= 0.0).all()
200211
finally:
201-
os.remove('/tmp/cesium_prediction_results.csv')
212+
os.remove(matching_downloads_paths[0])
202213

203214

204215
@pytest.mark.parametrize('featureset__name, model__type', [('regr', 'LinearRegressor')])
205-
def test_download_prediction_csv_regr(driver, project, dataset, featureset, model, prediction):
216+
def test_download_prediction_csv_regr(driver, project, dataset, featureset,
217+
model, prediction):
206218
driver.get('/')
207219
_click_download(project.id, driver)
208-
assert os.path.exists('/tmp/cesium_prediction_results.csv')
220+
matching_downloads_paths = glob.glob(f'{cfg["paths:downloads_folder"]}/'
221+
'cesium_prediction_results*.csv')
222+
assert len(matching_downloads_paths) == 1
209223
try:
210-
results = np.genfromtxt('/tmp/cesium_prediction_results.csv',
224+
results = np.genfromtxt(matching_downloads_paths[0],
211225
dtype='str', delimiter=',')
212226
npt.assert_equal(results[0],
213227
['ts_name', 'label', 'prediction'])
@@ -219,7 +233,7 @@ def test_download_prediction_csv_regr(driver, project, dataset, featureset, mode
219233
[3, 2.2, 2.2],
220234
[4, 3.1, 3.1]])
221235
finally:
222-
os.remove('/tmp/cesium_prediction_results.csv')
236+
os.remove(matching_downloads_paths[0])
223237

224238

225239
def test_predict_specific_ts_name(driver, project, dataset, featureset, model):

static/css/base.css

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,7 @@ body {
3333
.loginBox .logo {
3434
float: left;
3535
padding: 1em;
36-
}
36+
}
37+
a:hover {
38+
cursor:pointer;
39+
}

static/js/components/Download.jsx

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import React from 'react';
2+
import PropTypes from 'prop-types';
3+
4+
5+
const Download = (props) => {
6+
const style = {
7+
display: 'inline-block'
8+
};
9+
return (
10+
<a
11+
href={props.url}
12+
style={style}
13+
onClick={
14+
(e) => {
15+
e.stopPropagation();
16+
}
17+
}
18+
>
19+
Download
20+
</a>
21+
);
22+
};
23+
Download.propTypes = {
24+
url: PropTypes.string.isRequired
25+
};
26+
27+
export default Download;

static/js/components/Features.jsx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Plot from './Plot';
1313
import FoldableRow from './FoldableRow';
1414
import { reformatDatetime, contains } from '../utils';
1515
import Delete from './Delete';
16+
import Download from './Download';
1617

1718
const { Tab, Tabs, TabList, TabPanel } = { ...ReactTabs };
1819

@@ -257,7 +258,14 @@ export let FeatureTable = props => (
257258
<td>{featureset.name}</td>
258259
<td>{reformatDatetime(featureset.created_at)}</td>
259260
{status}
260-
<td><DeleteFeatureset ID={featureset.id} /></td>
261+
<td>
262+
{
263+
done &&
264+
<Download url={`/features/${featureset.id}/download`} />
265+
}
266+
&nbsp;&nbsp;
267+
<DeleteFeatureset ID={featureset.id} />
268+
</td>
261269
</tr>
262270
{foldedContent}
263271
</FoldableRow>
@@ -290,7 +298,6 @@ FeatureTable = connect(ftMapStateToProps)(FeatureTable);
290298
const mapDispatchToProps = dispatch => (
291299
{ delete: id => dispatch(Action.deleteFeatureset(id)) }
292300
);
293-
294301
const DeleteFeatureset = connect(null, mapDispatchToProps)(Delete);
295302

296303
export default FeaturesTab;

static/js/components/Models.jsx

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import * as Validate from '../validate';
1010
import * as Action from '../actions';
1111
import Expand from './Expand';
1212
import Delete from './Delete';
13+
import Download from './Download';
1314
import { $try, reformatDatetime } from '../utils';
1415
import FoldableRow from './FoldableRow';
1516

@@ -237,7 +238,14 @@ export let ModelTable = props => (
237238
<td>{model.name}</td>
238239
<td>{reformatDatetime(model.created_at)}</td>
239240
{status}
240-
<td><DeleteModel ID={model.id} /></td>
241+
<td>
242+
{
243+
done &&
244+
<Download url={`/models/${model.id}/download`} />
245+
}
246+
&nbsp;&nbsp;
247+
<DeleteModelButton ID={model.id} />
248+
</td>
241249
</tr>
242250
{foldedContent}
243251
</FoldableRow>
@@ -265,11 +273,10 @@ const mtMapStateToProps = (state, ownProps) => (
265273
ModelTable = connect(mtMapStateToProps)(ModelTable);
266274

267275

268-
const dmMapDispatchToProps = dispatch => (
276+
const deleteMapDispatchToProps = dispatch => (
269277
{ delete: id => dispatch(Action.deleteModel(id)) }
270278
);
271-
272-
const DeleteModel = connect(null, dmMapDispatchToProps)(Delete);
279+
const DeleteModelButton = connect(null, deleteMapDispatchToProps)(Delete);
273280

274281

275282
export default ModelsTab;

static/js/components/Predictions.jsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ let PredictionsTable = props => (
135135
<td>{reformatDatetime(prediction.created_at)}</td>
136136
{status}
137137
<td>
138-
<DownloadPredCSV ID={prediction.id} />
138+
{
139+
done &&
140+
<DownloadPredCSV ID={prediction.id} />
141+
}
139142
&nbsp;&nbsp;
140143
<DeletePrediction ID={prediction.id} />
141144
</td>

0 commit comments

Comments
 (0)