@@ -649,18 +649,19 @@ def _save_model(self, output_dir, storage_options):
649649 storage_options = storage_options ,
650650 )
651651
652+ def _validate_automlx_explanation_mode (self ):
653+ if self .spec .model != SupportedModels .AutoMLX and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX :
654+ raise ValueError (
655+ "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
656+ "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
657+ )
658+
652659 @runtime_dependency (
653660 module = "shap" ,
654661 err_msg = (
655662 "Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
656663 ),
657664 )
658- @runtime_dependency (
659- module = "automlx" ,
660- err_msg = (
661- "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662- ),
663- )
664665 def explain_model (self ):
665666 """
666667 Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -683,53 +684,13 @@ def explain_model(self):
683684 )
684685 ratio = SpeedAccuracyMode .ratio [self .spec .explanations_accuracy_mode ]
685686
687+ # validate the automlx mode is use for automlx model
688+ self ._validate_automlx_explanation_mode ()
689+
686690 for s_id , data_i in self .datasets .get_data_by_series (
687691 include_horizon = False
688692 ).items ():
689- if (
690- self .spec .model == SupportedModels .AutoMLX
691- and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
692- ):
693- import automlx
694-
695- explainer = automlx .MLExplainer (
696- self .models [s_id ],
697- self .datasets .additional_data .get_data_for_series (series_id = s_id )
698- .drop (self .spec .datetime_column .name , axis = 1 )
699- .head (- self .spec .horizon )
700- if self .spec .additional_data
701- else None ,
702- pd .DataFrame (data_i [self .spec .target_column ]),
703- task = "forecasting" ,
704- )
705-
706- explanations = explainer .explain_prediction (
707- X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
708- .drop (self .spec .datetime_column .name , axis = 1 )
709- .tail (self .spec .horizon )
710- if self .spec .additional_data
711- else None ,
712- forecast_timepoints = list (range (self .spec .horizon + 1 )),
713- )
714-
715- explanations_df = pd .concat (
716- [exp .to_dataframe () for exp in explanations ]
717- )
718- explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
719- explanations_df = explanations_df .pivot (
720- index = "row" , columns = "Feature" , values = "Attribution"
721- )
722- explanations_df = explanations_df .reset_index (drop = True )
723- self .local_explanation [s_id ] = explanations_df
724- elif (
725- self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
726- and self .spec .model != SupportedModels .AutoMLX
727- ):
728- raise ValueError (
729- "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
730- "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
731- )
732- elif s_id in self .models :
693+ if s_id in self .models :
733694 explain_predict_fn = self .get_explain_predict_fn (series_id = s_id )
734695 data_trimmed = data_i .tail (
735696 max (int (len (data_i ) * ratio ), 5 )
0 commit comments