@@ -578,36 +578,22 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
578578 exp_start_time = time .time ()
579579 global_ex_time = 0
580580 local_ex_time = 0
581+ logger .info (f"Calculating explanations using { self .spec .explanations_accuracy_mode } mode" )
581582 for series_id in self .target_columns :
582583 self .series_id = series_id
583- if self .spec .model == SupportedModels .AutoTS :
584- self .dataset_cols = (
585- self .full_data_long .loc [
586- self .full_data_long .series_id == self .category_mapping [self .series_id ]
587- ]
588- .set_index (datetime_col_name )
589- .columns
590- )
591-
592- self .bg_data = self .full_data_long .loc [
593- self .full_data_long .series_id == self .category_mapping [self .series_id ]
594- ].set_index (datetime_col_name )
595-
596- else :
597- self .dataset_cols = (
598- self .full_data_dict .get (series_id )
599- .set_index (datetime_col_name )
600- .drop (series_id , axis = 1 )
601- .columns
602- )
584+ self .dataset_cols = (
585+ self .full_data_dict .get (series_id )
586+ .set_index (datetime_col_name )
587+ .drop (series_id , axis = 1 )
588+ .columns
589+ )
603590
604- self .bg_data = self .full_data_dict .get (series_id ).set_index (
605- datetime_col_name
606- )
591+ self .bg_data = self .full_data_dict .get (series_id ).set_index (
592+ datetime_col_name
593+ )
607594 data = self .bg_data [list (self .dataset_cols )][: - self .spec .horizon ][
608595 list (self .dataset_cols )]
609596 ratio = SpeedAccuracyMode .ratio [self .spec .explanations_accuracy_mode ]
610- logger .info (f"Calculating explanations using { self .spec .explanations_accuracy_mode } mode" )
611597 data_trimmed = data .tail (max (int (len (data ) * ratio ), 100 )).reset_index ()
612598 data_trimmed [datetime_col_name ] = data_trimmed [datetime_col_name ].apply (lambda x : x .timestamp ())
613599 kernel_explnr = PermutationExplainer (
0 commit comments