@@ -666,10 +666,11 @@ def explain_model(self):
666666 lambda x : x .timestamp ()
667667 )
668668
669- # Explainer fails when boolean columns are passed
670- _ , data_trimmed_encoded = _label_encode_dataframe (
671- data_trimmed , no_encode = {datetime_col_name , self .original_target_column }
672- )
669+ # Explainer fails when boolean columns are passed for arima
670+ if self .spec .model == SupportedModels .Arima :
671+ _ , data_trimmed_encoded = _label_encode_dataframe (
672+ data_trimmed , no_encode = {datetime_col_name , self .original_target_column }
673+ )
673674
674675 kernel_explnr = PermutationExplainer (
675676 model = explain_predict_fn , masker = data_trimmed_encoded
@@ -714,15 +715,17 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
714715 kernel_explainer: The kernel explainer object to use for generating explanations.
715716 """
716717 data = self .datasets .get_horizon_at_series (s_id = series_id )
718+ # columns that were dropped in train_model in arima, should be dropped here as well
717719 if self .spec .model == SupportedModels .Arima and series_id in self .constant_cols :
718720 data = data .drop (columns = self .constant_cols [series_id ])
719721 data [datetime_col_name ] = datetime_to_seconds (data [datetime_col_name ])
720722 data = data .reset_index (drop = True )
721723
722- # Explainer fails when boolean columns are passed
723- _ , data = _label_encode_dataframe (
724- data , no_encode = {datetime_col_name , self .original_target_column }
725- )
724+ # Explainer fails when boolean columns are passed for arima
725+ if self .spec .model == SupportedModels .Arima :
726+ _ , data = _label_encode_dataframe (
727+ data , no_encode = {datetime_col_name , self .original_target_column }
728+ )
726729 # Generate local SHAP values using the kernel explainer
727730 local_kernel_explnr_vals = kernel_explainer .shap_values (data )
728731
0 commit comments