@@ -67,11 +67,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6767 model = GridSearchCV (model , params_to_optimize ,
6868 n_jobs = n_jobs )
6969 model .fit (fset , data ['labels' ])
70- score = model .score (fset , data ['labels' ])
70+
71+ metrics = {}
72+ metrics ['train_score' ] = model .score (fset , data ['labels' ])
73+
7174 best_params = model .best_params_ if params_to_optimize else {}
7275 joblib .dump (model , model_path )
7376
74- return score , best_params
77+ if model_type == 'RandomForestClassifier' :
78+ if params_to_optimize :
79+ model = model .best_estimator_
80+ if hasattr (model , 'oob_score_' ):
81+ metrics ['oob_score' ] = model .oob_score_
82+ if hasattr (model , 'feature_importances_' ):
83+ metrics ['feature_importances' ] = dict (zip (
84+ fset .columns .get_level_values (0 ).tolist (),
85+ model .feature_importances_ .tolist ()))
86+
87+ return metrics , best_params
7588
7689
7790class ModelHandler (BaseHandler ):
@@ -102,12 +115,12 @@ def get(self, model_id=None, action=None):
102115 @auth_or_token
103116 async def _await_model_statistics (self , model_stats_future , model ):
104117 try :
105- score , best_params = await model_stats_future
118+ model_metrics , best_params = await model_stats_future
106119
107120 model = DBSession ().merge (model )
108121 model .task_id = None
109122 model .finished = datetime .datetime .now ()
110- model .train_score = score
123+ model .metrics = model_metrics
111124 model .params .update (best_params )
112125 DBSession ().commit ()
113126
0 commit comments