2020from ads .opctl import logger
2121
2222from .. import utils
23- from ..const import SUMMARY_METRICS_HORIZON_LIMIT , SupportedMetrics , SupportedModels
23+ from ..const import SUMMARY_METRICS_HORIZON_LIMIT , SupportedMetrics , SupportedModels , SpeedAccuracyMode
2424from ..operator_config import ForecastOperatorConfig , ForecastOperatorSpec
2525from ads .common .decorator .runtime_dependency import runtime_dependency
2626from .forecast_datasets import ForecastDatasets , ForecastOutput
@@ -73,10 +73,10 @@ def generate_report(self):
7373 warnings .simplefilter (action = "ignore" , category = ConvergenceWarning )
7474 import datapane as dp
7575
76- # load data and build models
7776 start_time = time .time ()
7877 result_df = self ._build_model ()
7978 elapsed_time = time .time () - start_time
79+ logger .info ("Building the models completed in %s seconds" , elapsed_time )
8080
8181 # Generate metrics
8282 summary_metrics = None
@@ -574,49 +574,37 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
574574 dict: A dictionary containing the global explanation for each feature in the dataset.
575575 The keys are the feature names and the values are the average absolute SHAP values.
576576 """
577- from shap import KernelExplainer
578-
577+ from shap import PermutationExplainer
578+ exp_start_time = time .time ()
579+ global_ex_time = 0
580+ local_ex_time = 0
581+ logger .info (f"Calculating explanations using { self .spec .explanations_accuracy_mode } mode" )
579582 for series_id in self .target_columns :
580583 self .series_id = series_id
581- if self .spec .model == SupportedModels .AutoTS :
582- self .dataset_cols = (
583- self .full_data_long .loc [
584- self .full_data_long .series_id == self .series_id
585- ]
586- .set_index (datetime_col_name )
587- .columns
588- )
589-
590- self .bg_data = self .full_data_long .loc [
591- self .full_data_long .series_id == self .series_id
592- ].set_index (datetime_col_name )
593-
594- else :
595- self .dataset_cols = (
596- self .full_data_dict .get (series_id )
597- .set_index (datetime_col_name )
598- .drop (series_id , axis = 1 )
599- .columns
600- )
601-
602- self .bg_data = self .full_data_dict .get (series_id ).set_index (
603- datetime_col_name
604- )
584+ self .dataset_cols = (
585+ self .full_data_dict .get (series_id )
586+ .set_index (datetime_col_name )
587+ .drop (series_id , axis = 1 )
588+ .columns
589+ )
605590
606- kernel_explnr = KernelExplainer (
591+ self .bg_data = self .full_data_dict .get (series_id ).set_index (
592+ datetime_col_name
593+ )
594+ data = self .bg_data [list (self .dataset_cols )][: - self .spec .horizon ][
595+ list (self .dataset_cols )]
596+ ratio = SpeedAccuracyMode .ratio [self .spec .explanations_accuracy_mode ]
597+ data_trimmed = data .tail (max (int (len (data ) * ratio ), 100 )).reset_index ()
598+ data_trimmed [datetime_col_name ] = data_trimmed [datetime_col_name ].apply (lambda x : x .timestamp ())
599+ kernel_explnr = PermutationExplainer (
607600 model = explain_predict_fn ,
608- data = self .bg_data [list (self .dataset_cols )][: - self .spec .horizon ][
609- list (self .dataset_cols )
610- ],
601+ masker = data_trimmed ,
611602 keep_index = False
612603 if self .spec .model == SupportedModels .AutoMLX
613604 else True ,
614605 )
615606
616- kernel_explnr_vals = kernel_explnr .shap_values (
617- self .bg_data [: - self .spec .horizon ][list (self .dataset_cols )],
618- nsamples = 50 ,
619- )
607+ kernel_explnr_vals = kernel_explnr .shap_values (data_trimmed )
620608
621609 if not len (kernel_explnr_vals ):
622610 logger .warn (
@@ -625,14 +613,19 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
625613 else :
626614 self .global_explanation [series_id ] = dict (
627615 zip (
628- self . dataset_cols ,
629- np .average (np .absolute (kernel_explnr_vals ), axis = 0 ),
616+ data_trimmed . columns [ 1 :] ,
617+ np .average (np .absolute (kernel_explnr_vals [:, 1 :] ), axis = 0 ),
630618 )
631619 )
620+ exp_end_time = time .time ()
621+ global_ex_time = global_ex_time + exp_end_time - exp_start_time
632622
633623 self .local_explainer (
634624 kernel_explnr , series_id = series_id , datetime_col_name = datetime_col_name
635625 )
626+ local_ex_time = local_ex_time + time .time () - exp_end_time
627+ logger .info ("Global explanations generation completed in %s seconds" , global_ex_time )
628+ logger .info ("Local explanations generation completed in %s seconds" , local_ex_time )
636629
637630 def local_explainer (self , kernel_explainer , series_id , datetime_col_name ) -> None :
638631 """
@@ -644,18 +637,19 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
644637 """
645638 # Get the data for the series ID and select the relevant columns
646639 # data = self.full_data_dict.get(series_id).set_index(datetime_col_name)
647- data = self .bg_data [- self .spec .horizon :][list (self .dataset_cols )]
648-
640+ data_horizon = self .bg_data [- self .spec .horizon :][list (self .dataset_cols )]
641+ data = data_horizon .reset_index ()
642+ data [datetime_col_name ] = data [datetime_col_name ].apply (lambda x : x .timestamp ())
649643 # Generate local SHAP values using the kernel explainer
650- local_kernel_explnr_vals = kernel_explainer .shap_values (data , nsamples = 50 )
644+ local_kernel_explnr_vals = kernel_explainer .shap_values (data )
651645
652646 # Convert the SHAP values into a DataFrame
653647 local_kernel_explnr_df = pd .DataFrame (
654- local_kernel_explnr_vals , columns = self . dataset_cols
648+ local_kernel_explnr_vals [:, 1 :], columns = data . columns [ 1 :]
655649 )
656650
657651 # set the index of the DataFrame to the datetime column
658- local_kernel_explnr_df .index = data .index
652+ local_kernel_explnr_df .index = data_horizon .index
659653
660654 if self .spec .model == SupportedModels .AutoTS :
661655 local_kernel_explnr_df .drop (
0 commit comments