4747 SpeedAccuracyMode ,
4848 SupportedMetrics ,
4949 SupportedModels ,
50- BACKTEST_REPORT_NAME
50+ BACKTEST_REPORT_NAME ,
5151)
5252from ..operator_config import ForecastOperatorConfig , ForecastOperatorSpec
5353from .forecast_datasets import ForecastDatasets
@@ -259,7 +259,11 @@ def generate_report(self):
259259 output_dir = self .spec .output_directory .url
260260 file_path = f"{ output_dir } /{ BACKTEST_REPORT_NAME } "
261261 if self .spec .model == AUTO_SELECT :
262- backtest_sections .append (rc .Heading ("Auto-Select Backtesting and Performance Metrics" , level = 2 ))
262+ backtest_sections .append (
263+ rc .Heading (
264+ "Auto-Select Backtesting and Performance Metrics" , level = 2
265+ )
266+ )
263267 if not os .path .exists (file_path ):
264268 failure_msg = rc .Text (
265269 "auto-select could not be executed. Please check the "
@@ -268,15 +272,23 @@ def generate_report(self):
268272 backtest_sections .append (failure_msg )
269273 else :
270274 backtest_stats = pd .read_csv (file_path )
271- model_metric_map = backtest_stats .drop (columns = ['metric' , 'backtest' ])
272- average_dict = {k : round (v , 4 ) for k , v in model_metric_map .mean ().to_dict ().items ()}
275+ model_metric_map = backtest_stats .drop (
276+ columns = ["metric" , "backtest" ]
277+ )
278+ average_dict = {
279+ k : round (v , 4 )
280+ for k , v in model_metric_map .mean ().to_dict ().items ()
281+ }
273282 best_model = min (average_dict , key = average_dict .get )
274283 summary_text = rc .Text (
275284 f"Overall, the average { self .spec .metric } scores for the models are { average_dict } , with"
276- f" { best_model } being identified as the top-performing model during backtesting." )
285+ f" { best_model } being identified as the top-performing model during backtesting."
286+ )
277287 backtest_table = rc .DataTable (backtest_stats , index = True )
278288 liner_plot = get_auto_select_plot (backtest_stats )
279- backtest_sections .extend ([backtest_table , summary_text , liner_plot ])
289+ backtest_sections .extend (
290+ [backtest_table , summary_text , liner_plot ]
291+ )
280292
281293 forecast_plots = []
282294 if len (self .forecast_output .list_series_ids ()) > 0 :
@@ -643,6 +655,12 @@ def _save_model(self, output_dir, storage_options):
643655 "Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
644656 ),
645657 )
658+ @runtime_dependency (
659+ module = "automlx" ,
660+ err_msg = (
661+ "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662+ ),
663+ )
646664 def explain_model (self ):
647665 """
648666 Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -668,7 +686,44 @@ def explain_model(self):
668686 for s_id , data_i in self .datasets .get_data_by_series (
669687 include_horizon = False
670688 ).items ():
671- if s_id in self .models :
689+ if (
690+ self .spec .model == SupportedModels .AutoMLX
691+ and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
692+ ):
693+ import automlx
694+
695+ explainer = automlx .MLExplainer (
696+ self .models [s_id ],
697+ self .datasets .additional_data .get_data_for_series (series_id = s_id )
698+ .drop (self .spec .datetime_column .name , axis = 1 )
699+ .head (- self .spec .horizon )
700+ if self .spec .additional_data
701+ else None ,
702+ pd .DataFrame (data_i [self .spec .target_column ]),
703+ task = "forecasting" ,
704+ )
705+
706+ explanations = explainer .explain_prediction (
707+ X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
708+ .drop (self .spec .datetime_column .name , axis = 1 )
709+ .tail (self .spec .horizon )
710+ if self .spec .additional_data
711+ else None ,
712+ forecast_timepoints = list (range (self .spec .horizon + 1 )),
713+ )
714+
715+ explanations_df = pd .concat (
716+ [exp .to_dataframe () for exp in explanations ]
717+ )
718+ explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
719+ explanations_df = explanations_df .pivot (
720+ index = "row" , columns = "Feature" , values = "Attribution"
721+ )
722+ explanations_df = explanations_df .reset_index (drop = True )
723+ # explanations_df[self.spec.datetime_column.name]=self.datasets.additional_data.get_data_for_series(series_id=s_id).tail(self.spec.horizon)[self.spec.datetime_column.name].reset_index(drop=True)
724+ self .local_explanation [s_id ] = explanations_df
725+
726+ elif s_id in self .models :
672727 explain_predict_fn = self .get_explain_predict_fn (series_id = s_id )
673728 data_trimmed = data_i .tail (
674729 max (int (len (data_i ) * ratio ), 5 )
@@ -699,6 +754,14 @@ def explain_model(self):
699754 logger .warn (
700755 "No explanations generated. Ensure that additional data has been provided."
701756 )
757+ elif (
758+ self .spec .model == SupportedModels .AutoMLX
759+ and self .spec .explanations_accuracy_mode
760+ == SpeedAccuracyMode .AUTOMLX
761+ ):
762+ logger .warning (
763+ "Global explanations not available for AutoMLX models with inherent explainability"
764+ )
702765 else :
703766 self .global_explanation [s_id ] = dict (
704767 zip (
0 commit comments