2222from ..const import (
2323 DEFAULT_TRIALS ,
2424 PROPHET_INTERNAL_DATE_COL ,
25- ForecastOutputColumns ,
2625 SupportedModels ,
2726)
2827from .base_model import ForecastOperatorBaseModel
@@ -123,6 +122,14 @@ def _train_model(self, i, series_id, df, model_kwargs):
123122 upper_bound = self .get_horizon (forecast ["yhat_upper" ]).values ,
124123 lower_bound = self .get_horizon (forecast ["yhat_lower" ]).values ,
125124 )
125+ # Get all features that make up the forecast. Exclude CI (upper/lower) and drop yhat ([:-1])
126+ core_columns = forecast .columns [
127+ ~ forecast .columns .str .endswith ("_lower" )
128+ & ~ forecast .columns .str .endswith ("_upper" )
129+ ][:- 1 ]
130+ self .explanations_info [series_id ] = (
131+ forecast [core_columns ].rename ({"ds" : "Date" }, axis = 1 ).set_index ("Date" )
132+ )
126133
127134 self .models [series_id ] = {}
128135 self .models [series_id ]["model" ] = model
@@ -151,6 +158,7 @@ def _build_model(self) -> pd.DataFrame:
151158 full_data_dict = self .datasets .get_data_by_series ()
152159 self .models = {}
153160 self .outputs = {}
161+ self .explanations_info = {}
154162 self .additional_regressors = self .datasets .get_additional_data_column_names ()
155163 model_kwargs = self .set_kwargs ()
156164 self .forecast_output = ForecastOutput (
@@ -257,6 +265,25 @@ def objective(trial):
257265 model_kwargs_i = study .best_params
258266 return model_kwargs_i
259267
268+ def explain_model (self ):
269+ self .local_explanation = {}
270+ global_expl = []
271+
272+ for s_id , expl_df in self .explanations_info .items ():
273+ # Local Expl
274+ self .local_explanation [s_id ] = self .get_horizon (expl_df )
275+ self .local_explanation [s_id ]["Series" ] = s_id
276+ self .local_explanation [s_id ].index .rename (self .dt_column_name , inplace = True )
277+ # Global Expl
278+ g_expl = self .drop_horizon (expl_df ).mean ()
279+ g_expl .name = s_id
280+ global_expl .append (g_expl )
281+ self .global_explanation = pd .concat (global_expl , axis = 1 )
282+ self .formatted_global_explanation = (
283+ self .global_explanation / self .global_explanation .sum (axis = 0 ) * 100
284+ )
285+ self .formatted_local_explanation = pd .concat (self .local_explanation .values ())
286+
260287 def _generate_report (self ):
261288 import report_creator as rc
262289 from prophet .plot import add_changepoints_to_plot
@@ -335,22 +362,6 @@ def _generate_report(self):
335362 # If the key is present, call the "explain_model" method
336363 self .explain_model ()
337364
338- # Convert the global explanation data to a DataFrame
339- global_explanation_df = pd .DataFrame (self .global_explanation )
340-
341- self .formatted_global_explanation = (
342- global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
343- )
344-
345- aggregate_local_explanations = pd .DataFrame ()
346- for s_id , local_ex_df in self .local_explanation .items ():
347- local_ex_df_copy = local_ex_df .copy ()
348- local_ex_df_copy [ForecastOutputColumns .SERIES ] = s_id
349- aggregate_local_explanations = pd .concat (
350- [aggregate_local_explanations , local_ex_df_copy ], axis = 0
351- )
352- self .formatted_local_explanation = aggregate_local_explanations
353-
354365 if not self .target_cat_col :
355366 self .formatted_global_explanation = (
356367 self .formatted_global_explanation .rename (
@@ -364,7 +375,7 @@ def _generate_report(self):
364375
365376 # Create a markdown section for the global explainability
366377 global_explanation_section = rc .Block (
367- rc .Heading ("Global Explanation of Models " , level = 2 ),
378+ rc .Heading ("Global Explainability " , level = 2 ),
368379 rc .Text (
369380 "The following tables provide the feature attribution for the global explainability."
370381 ),
@@ -373,7 +384,7 @@ def _generate_report(self):
373384
374385 blocks = [
375386 rc .DataTable (
376- local_ex_df .div ( local_ex_df . abs (). sum ( axis = 1 ) , axis = 0 ) * 100 ,
387+ local_ex_df .drop ( "Series" , axis = 1 ) ,
377388 label = s_id if self .target_cat_col else None ,
378389 index = True ,
379390 )
@@ -393,6 +404,8 @@ def _generate_report(self):
393404 # Do not fail the whole run due to explanations failure
394405 logger .warning (f"Failed to generate Explanations with error: { e } ." )
395406 logger .debug (f"Full Traceback: { traceback .format_exc ()} " )
407+ self .errors_dict ["explainer_error" ] = str (e )
408+ self .errors_dict ["explainer_error_error" ] = traceback .format_exc ()
396409
397410 model_description = rc .Text (
398411 """Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
0 commit comments