@@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6363 if params_to_optimize :
6464 model = GridSearchCV (model , params_to_optimize )
6565 model .fit (fset , data ['labels' ])
66- score = model .score (fset , data ['labels' ])
66+
67+ metrics = {}
68+ metrics ['train_score' ] = model .score (fset , data ['labels' ])
69+
6770 best_params = model .best_params_ if params_to_optimize else {}
6871 joblib .dump (model , model_path )
6972
70- return score , best_params
73+ if model_type == 'RandomForestClassifier' :
74+ if params_to_optimize :
75+ model = model .best_estimator_
76+ if hasattr (model , 'oob_score_' ):
77+ metrics ['oob_score' ] = model .oob_score_
78+ if hasattr (model , 'feature_importances_' ):
79+ metrics ['feature_importances' ] = dict (zip (
80+ fset .columns .get_level_values (0 ).tolist (),
81+ model .feature_importances_ .tolist ()))
82+
83+ return metrics , best_params
7184
7285
7386class ModelHandler (BaseHandler ):
@@ -84,12 +97,12 @@ def get(self, model_id=None):
8497 @auth_or_token
8598 async def _await_model_statistics (self , model_stats_future , model ):
8699 try :
87- score , best_params = await model_stats_future
100+ model_metrics , best_params = await model_stats_future
88101
89102 model = DBSession ().merge (model )
90103 model .task_id = None
91104 model .finished = datetime .datetime .now ()
92- model .train_score = score
105+ model .metrics = model_metrics
93106 model .params .update (best_params )
94107 DBSession ().commit ()
95108
0 commit comments