@@ -150,12 +150,22 @@ def _train_model(self, i, s_id, df, model_kwargs):
150150 logger .debug (forecast .tail ())
151151
152152 # TODO; could also extract trend and seasonality?
153- cols_to_read = filter (
154- lambda x : x .startswith ("future_regressor" ), forecast .columns
153+ cols_to_read = set (
154+ forecast .columns [forecast .columns .str .startswith ("future_regressor" )]
155+ + ["ds" , "trend" ]
155156 )
156- self .explanations_info [s_id ] = (
157- forecast [cols_to_read ].rename ({"ds" : "Date" }, axis = 1 ).set_index ("Date" )
157+ cols_to_read = cols_to_read - {
158+ "future_regressors_additive" ,
159+ "future_regressors_multiplicative" ,
160+ }
161+ combine_terms = cols_to_read - set (self .accepted_regressors [s_id ])
162+ temp_df = (
163+ forecast [list (cols_to_read )]
164+ .rename ({"ds" : "Date" }, axis = 1 )
165+ .set_index ("Date" )
158166 )
167+ temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
168+ self .explanations_info [s_id ] = temp_df .drop (combine_terms , axis = 1 )
159169
160170 self .outputs [s_id ] = forecast
161171 self .forecast_output .populate_series_output (
@@ -457,19 +467,14 @@ def explain_model(self):
457467 for s_id , expl_df in self .explanations_info .items ():
458468 expl_df = expl_df .rename (rename_cols , axis = 1 )
459469 # Local Expl
460- self .local_explanation [s_id ] = self .get_horizon (expl_df ).drop (
461- ["future_regressors_additive" ], axis = 1
462- )
470+ self .local_explanation [s_id ] = self .get_horizon (expl_df )
463471 self .local_explanation [s_id ]["Series" ] = s_id
464472 self .local_explanation [s_id ].index .rename (self .dt_column_name , inplace = True )
465473 # Global Expl
466474 g_expl = self .drop_horizon (expl_df ).mean ()
467475 g_expl .name = s_id
468476 global_expl .append (g_expl )
469477 self .global_explanation = pd .concat (global_expl , axis = 1 )
470- self .global_explanation = self .global_explanation .drop (
471- index = ["future_regressors_additive" ], axis = 0
472- )
473478 self .formatted_global_explanation = (
474479 self .global_explanation / self .global_explanation .sum (axis = 0 ) * 100
475480 )
0 commit comments