Skip to content

Commit b7a5888

Browse files
committed
Change download table formatting
1 parent 993cb7e commit b7a5888

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

cesium_app/handlers/prediction.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,17 @@ def get(self, prediction_id=None, action=None):
137137
if action == 'download':
138138
pred_path = self._get_prediction(prediction_id).file.uri
139139
fset, data = featurize.load_featureset(pred_path)
140-
result = pd.DataFrame({'ts_name': fset.index,
141-
'label': data['labels'],
142-
'prediction': data['preds']},
143-
columns=['ts_name', 'label', 'prediction'])
140+
result = pd.DataFrame({'label': data['labels']},
141+
index=fset.index)
144142
if len(data.get('pred_probs', [])) > 0:
145-
result['probability'] = data['pred_probs'].max(axis=1).values
143+
result = pd.concat((result, data['pred_probs']), axis=1)
144+
else:
145+
result['prediction'] = data['preds']
146+
result.index.name = 'ts_name'
146147
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
147148
self.set_header("Content-Disposition", "attachment; "
148149
"filename=cesium_prediction_results.csv")
149-
self.write(result.to_csv(index=False))
150+
self.write(result.to_csv(index=True))
150151
else:
151152
if prediction_id is None:
152153
predictions = [prediction

cesium_app/tests/frontend/test_predict.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,10 @@ def test_download_prediction_csv_class_prob(driver):
200200
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
201201
'Mira', 'Classical_Cepheid',
202202
'Mira'])
203-
npt.assert_array_equal(result.label, result.prediction)
204-
assert (result.probability >= 0.0).all()
203+
pred_probs = result[['Classical_Cepheid', 'Mira']]
204+
npt.assert_array_equal(np.argmax(pred_probs.values, axis=1),
205+
[1, 0, 1, 0, 1])
206+
assert (pred_probs.values >= 0.0).all()
205207
finally:
206208
os.remove('/tmp/cesium_prediction_results.csv')
207209

0 commit comments

Comments
 (0)