2020from ads .opctl import logger
2121
2222from .. import utils
23- from ..const import SUMMARY_METRICS_HORIZON_LIMIT , SupportedMetrics , SupportedModels
23+ from ..const import SUMMARY_METRICS_HORIZON_LIMIT , SupportedMetrics , SupportedModels , SpeedAccuracyMode
2424from ..operator_config import ForecastOperatorConfig , ForecastOperatorSpec
2525from ads .common .decorator .runtime_dependency import runtime_dependency
2626from .forecast_datasets import ForecastDatasets , ForecastOutput
@@ -73,10 +73,10 @@ def generate_report(self):
7373 warnings .simplefilter (action = "ignore" , category = ConvergenceWarning )
7474 import datapane as dp
7575
76- # load data and build models
7776 start_time = time .time ()
7877 result_df = self ._build_model ()
7978 elapsed_time = time .time () - start_time
79+ logger .info ("Building the models completed in %s seconds" , elapsed_time )
8080
8181 # Generate metrics
8282 summary_metrics = None
@@ -574,21 +574,23 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
574574 dict: A dictionary containing the global explanation for each feature in the dataset.
575575 The keys are the feature names and the values are the average absolute SHAP values.
576576 """
577- from shap import KernelExplainer
578-
577+ from shap import PermutationExplainer
578+ exp_start_time = time .time ()
579+ global_ex_time = 0
580+ local_ex_time = 0
579581 for series_id in self .target_columns :
580582 self .series_id = series_id
581583 if self .spec .model == SupportedModels .AutoTS :
582584 self .dataset_cols = (
583585 self .full_data_long .loc [
584- self .full_data_long .series_id == self .series_id
586+ self .full_data_long .series_id == self .category_mapping [ self . series_id ]
585587 ]
586588 .set_index (datetime_col_name )
587589 .columns
588590 )
589591
590592 self .bg_data = self .full_data_long .loc [
591- self .full_data_long .series_id == self .series_id
593+ self .full_data_long .series_id == self .category_mapping [ self . series_id ]
592594 ].set_index (datetime_col_name )
593595
594596 else :
@@ -602,21 +604,21 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
602604 self .bg_data = self .full_data_dict .get (series_id ).set_index (
603605 datetime_col_name
604606 )
605-
606- kernel_explnr = KernelExplainer (
607+ data = self .bg_data [list (self .dataset_cols )][: - self .spec .horizon ][
608+ list (self .dataset_cols )]
609+ ratio = SpeedAccuracyMode .ratio [self .spec .explanations_accuracy_mode ]
610+ logger .info (f"Calculating explanations using { self .spec .explanations_accuracy_mode } mode" )
611+ data_trimmed = data .tail (max (int (len (data ) * ratio ), 100 )).reset_index ()
612+ data_trimmed [datetime_col_name ] = data_trimmed [datetime_col_name ].apply (lambda x : x .timestamp ())
613+ kernel_explnr = PermutationExplainer (
607614 model = explain_predict_fn ,
608- data = self .bg_data [list (self .dataset_cols )][: - self .spec .horizon ][
609- list (self .dataset_cols )
610- ],
615+ masker = data_trimmed ,
611616 keep_index = False
612617 if self .spec .model == SupportedModels .AutoMLX
613618 else True ,
614619 )
615620
616- kernel_explnr_vals = kernel_explnr .shap_values (
617- self .bg_data [: - self .spec .horizon ][list (self .dataset_cols )],
618- nsamples = 50 ,
619- )
621+ kernel_explnr_vals = kernel_explnr .shap_values (data_trimmed )
620622
621623 if not len (kernel_explnr_vals ):
622624 logger .warn (
@@ -625,14 +627,19 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
625627 else :
626628 self .global_explanation [series_id ] = dict (
627629 zip (
628- self . dataset_cols ,
629- np .average (np .absolute (kernel_explnr_vals ), axis = 0 ),
630+ data_trimmed . columns [ 1 :] ,
631+ np .average (np .absolute (kernel_explnr_vals [:, 1 :] ), axis = 0 ),
630632 )
631633 )
634+ exp_end_time = time .time ()
635+ global_ex_time = global_ex_time + exp_end_time - exp_start_time
632636
633637 self .local_explainer (
634638 kernel_explnr , series_id = series_id , datetime_col_name = datetime_col_name
635639 )
640+ local_ex_time = local_ex_time + time .time () - exp_end_time
641+ logger .info ("Global explanations generation completed in %s seconds" , global_ex_time )
642+ logger .info ("Local explanations generation completed in %s seconds" , local_ex_time )
636643
637644 def local_explainer (self , kernel_explainer , series_id , datetime_col_name ) -> None :
638645 """
@@ -644,18 +651,19 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
644651 """
645652 # Get the data for the series ID and select the relevant columns
646653 # data = self.full_data_dict.get(series_id).set_index(datetime_col_name)
647- data = self .bg_data [- self .spec .horizon :][list (self .dataset_cols )]
648-
654+ data_horizon = self .bg_data [- self .spec .horizon :][list (self .dataset_cols )]
655+ data = data_horizon .reset_index ()
656+ data [datetime_col_name ] = data [datetime_col_name ].apply (lambda x : x .timestamp ())
649657 # Generate local SHAP values using the kernel explainer
650- local_kernel_explnr_vals = kernel_explainer .shap_values (data , nsamples = 50 )
658+ local_kernel_explnr_vals = kernel_explainer .shap_values (data )
651659
652660 # Convert the SHAP values into a DataFrame
653661 local_kernel_explnr_df = pd .DataFrame (
654- local_kernel_explnr_vals , columns = self . dataset_cols
662+ local_kernel_explnr_vals [:, 1 :], columns = data . columns [ 1 :]
655663 )
656664
657665 # set the index of the DataFrame to the datetime column
658- local_kernel_explnr_df .index = data .index
666+ local_kernel_explnr_df .index = data_horizon .index
659667
660668 if self .spec .model == SupportedModels .AutoTS :
661669 local_kernel_explnr_df .drop (
0 commit comments