Skip to content

Commit b74f65c

Browse files
committed
Add OOB score and feature importance chart to displayed model metrics
1 parent ee3db0e commit b74f65c

File tree

5 files changed

+61
-9
lines changed

5 files changed

+61
-9
lines changed

cesium_app/handlers/model.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6363
if params_to_optimize:
6464
model = GridSearchCV(model, params_to_optimize)
6565
model.fit(fset, data['labels'])
66-
score = model.score(fset, data['labels'])
66+
67+
metrics = {}
68+
metrics['train_score'] = model.score(fset, data['labels'])
69+
6770
best_params = model.best_params_ if params_to_optimize else {}
6871
joblib.dump(model, model_path)
6972

70-
return score, best_params
73+
if model_type == 'RandomForestClassifier':
74+
if params_to_optimize:
75+
model = model.best_estimator_
76+
if hasattr(model, 'oob_score_'):
77+
metrics['oob_score'] = model.oob_score_
78+
if hasattr(model, 'feature_importances_'):
79+
metrics['feature_importances'] = dict(zip(
80+
fset.columns.get_level_values(0).tolist(),
81+
model.feature_importances_.tolist()))
82+
83+
return metrics, best_params
7184

7285

7386
class ModelHandler(BaseHandler):
@@ -84,12 +97,12 @@ def get(self, model_id=None):
8497
@auth_or_token
8598
async def _await_model_statistics(self, model_stats_future, model):
8699
try:
87-
score, best_params = await model_stats_future
100+
model_metrics, best_params = await model_stats_future
88101

89102
model = DBSession().merge(model)
90103
model.task_id = None
91104
model.finished = datetime.datetime.now()
92-
model.train_score = score
105+
model.metrics = model_metrics
93106
model.params.update(best_params)
94107
DBSession().commit()
95108

cesium_app/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Model(Base):
8989
file_uri = sa.Column(sa.String(), nullable=True, index=True)
9090
task_id = sa.Column(sa.String())
9191
finished = sa.Column(sa.DateTime)
92-
train_score = sa.Column(sa.Float)
92+
metrics = sa.Column(sa.JSON, nullable=True)
9393

9494
featureset = relationship('Featureset')
9595
project = relationship('Project')

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
"bokehjs": "^0.12.5",
1010
"bootstrap": "^3.3.7",
1111
"bootstrap-css": "^3.0.0",
12+
"chart.js": "^2.7.1",
1213
"css-loader": "^0.26.2",
1314
"exports-loader": "^0.6.4",
1415
"imports-loader": "^0.7.1",
1516
"jquery": "^3.1.1",
1617
"prop-types": "^15.5.10",
1718
"react": "^15.1.0",
19+
"react-chartjs-2": "^2.7.0",
1820
"react-dom": "^15.1.0",
1921
"react-redux": "^5.0.3",
2022
"react-tabs": "^0.8.2",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import React from 'react';
2+
import { HorizontalBar } from 'react-chartjs-2';
3+
4+
5+
const FeatureImportancesBarchart = props => {
6+
const sorted_features = Object.keys(props.data).sort(
7+
(a, b) => props.data[b] - props.data[a]).slice(0, 15);
8+
const values = sorted_features.map(
9+
feature => props.data[feature].toFixed(3));
10+
const data = {
11+
labels: sorted_features,
12+
datasets: [
13+
{
14+
label: 'Feature Importance',
15+
backgroundColor: '#2222ff',
16+
hoverBackgroundColor: '#5555ff',
17+
data: values
18+
}
19+
]
20+
};
21+
22+
return (
23+
<div style={{ height: 300, width: 600 }}>
24+
<HorizontalBar data={data} />
25+
</div>
26+
);
27+
};
28+
29+
export default FeatureImportancesBarchart;

static/js/components/Models.jsx

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import Expand from './Expand';
1212
import Delete from './Delete';
1313
import { $try, reformatDatetime } from '../utils';
1414
import FoldableRow from './FoldableRow';
15+
import FeatureImportances from './FeatureImportances';
1516

1617

1718
const ModelsTab = props => (
@@ -177,7 +178,7 @@ const ModelInfo = props => (
177178
<tr>
178179
<th>Model Type</th>
179180
<th>Hyperparameters</th>
180-
<th>Training Data Score</th>
181+
{Object.keys(props.model.metrics).map(metric => <th>{metric}</th>)}
181182
</tr>
182183
</thead>
183184
<tbody>
@@ -199,9 +200,16 @@ const ModelInfo = props => (
199200
</tbody>
200201
</table>
201202
</td>
202-
<td>
203-
{props.model.train_score}
204-
</td>
203+
{
204+
Object.keys(props.model.metrics).map(metric => (
205+
<td>
206+
{
207+
metric == 'feature_importances' ?
208+
<FeatureImportances data={props.model.metrics[metric]} /> :
209+
props.model.metrics[metric]
210+
}
211+
</td>))
212+
}
205213
</tr>
206214
</tbody>
207215
</table>

0 commit comments

Comments
 (0)