Skip to content

Commit 4629ba5

Browse files
committed
Faster explainability
1 parent f5d39a9 commit 4629ba5

File tree

8 files changed

+64
-26
lines changed

8 files changed

+64
-26
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from ads.common.extended_enum import ExtendedEnumMeta
88

9-
109
class SupportedModels(str, metaclass=ExtendedEnumMeta):
1110
"""Supported forecast models."""
1211

@@ -18,6 +17,18 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1817
Auto = "auto"
1918

2019

20+
class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
21+
"""
22+
Enum representing different modes based on time taken and accuracy for explainability.
23+
"""
24+
HIGH_ACCURACY = "HIGH_ACCURACY"
25+
BALANCED = "BALANCED"
26+
FAST_APPROXIMATE = "FAST_APPROXIMATE"
27+
ratio = dict()
28+
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
29+
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
30+
ratio[FAST_APPROXIMATE] = 0 # constant
31+
2132
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
2233
"""Supported forecast metrics."""
2334

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def _custom_predict_arima(self, data):
251251
array-like: The predicted values.
252252
253253
"""
254+
date_col = self.spec.datetime_column.name
255+
data[date_col] = pd.to_datetime(data[date_col], unit='s')
256+
data = data.set_index(date_col)
254257
# Get the index of the current series id
255258
series_index = self.target_columns.index(self.series_id)
256259

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def _custom_predict_automlx(self, data):
316316
-------
317317
numpy.ndarray: The predicted future values of the time series.
318318
"""
319-
temp = 0
320319
data_temp = pd.DataFrame(
321320
data,
322321
columns=[col for col in self.dataset_cols],

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ads.opctl import logger
2121

2222
from .. import utils
23-
from ..const import SUMMARY_METRICS_HORIZON_LIMIT, SupportedMetrics, SupportedModels
23+
from ..const import SUMMARY_METRICS_HORIZON_LIMIT, SupportedMetrics, SupportedModels, SpeedAccuracyMode
2424
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
2525
from ads.common.decorator.runtime_dependency import runtime_dependency
2626
from .forecast_datasets import ForecastDatasets, ForecastOutput
@@ -73,10 +73,10 @@ def generate_report(self):
7373
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
7474
import datapane as dp
7575

76-
# load data and build models
7776
start_time = time.time()
7877
result_df = self._build_model()
7978
elapsed_time = time.time() - start_time
79+
logger.info("Building the models completed in %s seconds", elapsed_time)
8080

8181
# Generate metrics
8282
summary_metrics = None
@@ -574,21 +574,23 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
574574
dict: A dictionary containing the global explanation for each feature in the dataset.
575575
The keys are the feature names and the values are the average absolute SHAP values.
576576
"""
577-
from shap import KernelExplainer
578-
577+
from shap import PermutationExplainer
578+
exp_start_time = time.time()
579+
global_ex_time = 0
580+
local_ex_time = 0
579581
for series_id in self.target_columns:
580582
self.series_id = series_id
581583
if self.spec.model == SupportedModels.AutoTS:
582584
self.dataset_cols = (
583585
self.full_data_long.loc[
584-
self.full_data_long.series_id == self.series_id
586+
self.full_data_long.series_id == self.category_mapping[self.series_id]
585587
]
586588
.set_index(datetime_col_name)
587589
.columns
588590
)
589591

590592
self.bg_data = self.full_data_long.loc[
591-
self.full_data_long.series_id == self.series_id
593+
self.full_data_long.series_id == self.category_mapping[self.series_id]
592594
].set_index(datetime_col_name)
593595

594596
else:
@@ -602,21 +604,21 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
602604
self.bg_data = self.full_data_dict.get(series_id).set_index(
603605
datetime_col_name
604606
)
605-
606-
kernel_explnr = KernelExplainer(
607+
data = self.bg_data[list(self.dataset_cols)][: -self.spec.horizon][
608+
list(self.dataset_cols)]
609+
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
610+
logger.info(f"Calculating explanations using {self.spec.explanations_accuracy_mode} mode")
611+
data_trimmed = data.tail(max(int(len(data) * ratio), 100)).reset_index()
612+
data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(lambda x: x.timestamp())
613+
kernel_explnr = PermutationExplainer(
607614
model=explain_predict_fn,
608-
data=self.bg_data[list(self.dataset_cols)][: -self.spec.horizon][
609-
list(self.dataset_cols)
610-
],
615+
masker=data_trimmed,
611616
keep_index=False
612617
if self.spec.model == SupportedModels.AutoMLX
613618
else True,
614619
)
615620

616-
kernel_explnr_vals = kernel_explnr.shap_values(
617-
self.bg_data[: -self.spec.horizon][list(self.dataset_cols)],
618-
nsamples=50,
619-
)
621+
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed)
620622

621623
if not len(kernel_explnr_vals):
622624
logger.warn(
@@ -625,14 +627,19 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
625627
else:
626628
self.global_explanation[series_id] = dict(
627629
zip(
628-
self.dataset_cols,
629-
np.average(np.absolute(kernel_explnr_vals), axis=0),
630+
data_trimmed.columns[1:],
631+
np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
630632
)
631633
)
634+
exp_end_time = time.time()
635+
global_ex_time = global_ex_time + exp_end_time - exp_start_time
632636

633637
self.local_explainer(
634638
kernel_explnr, series_id=series_id, datetime_col_name=datetime_col_name
635639
)
640+
local_ex_time = local_ex_time + time.time() - exp_end_time
641+
logger.info("Global explanations generation completed in %s seconds", global_ex_time)
642+
logger.info("Local explanations generation completed in %s seconds", local_ex_time)
636643

637644
def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> None:
638645
"""
@@ -644,18 +651,19 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
644651
"""
645652
# Get the data for the series ID and select the relevant columns
646653
# data = self.full_data_dict.get(series_id).set_index(datetime_col_name)
647-
data = self.bg_data[-self.spec.horizon :][list(self.dataset_cols)]
648-
654+
data_horizon = self.bg_data[-self.spec.horizon:][list(self.dataset_cols)]
655+
data = data_horizon.reset_index()
656+
data[datetime_col_name] = data[datetime_col_name].apply(lambda x: x.timestamp())
649657
# Generate local SHAP values using the kernel explainer
650-
local_kernel_explnr_vals = kernel_explainer.shap_values(data, nsamples=50)
658+
local_kernel_explnr_vals = kernel_explainer.shap_values(data)
651659

652660
# Convert the SHAP values into a DataFrame
653661
local_kernel_explnr_df = pd.DataFrame(
654-
local_kernel_explnr_vals, columns=self.dataset_cols
662+
local_kernel_explnr_vals[:, 1:], columns=data.columns[1:]
655663
)
656664

657665
# set the index of the DataFrame to the datetime column
658-
local_kernel_explnr_df.index = data.index
666+
local_kernel_explnr_df.index = data_horizon.index
659667

660668
if self.spec.model == SupportedModels.AutoTS:
661669
local_kernel_explnr_df.drop(

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pandas as pd
1313
from ..const import ForecastOutputColumns, PROPHET_INTERNAL_DATE_COL
1414
from pandas.api.types import is_datetime64_any_dtype, is_string_dtype, is_numeric_dtype
15-
15+
import time
1616

1717
class ForecastDatasets:
1818
def __init__(self, config: ForecastOperatorConfig):
@@ -36,14 +36,19 @@ def __init__(self, config: ForecastOperatorConfig):
3636
def _load_data(self, spec):
3737
"""Loads forecasting input data."""
3838

39+
loading_start_time = time.time()
3940
raw_data = utils._load_data(
4041
filename=spec.historical_data.url,
4142
format=spec.historical_data.format,
4243
columns=spec.historical_data.columns,
4344
)
45+
loading_end_time = time.time()
46+
logger.info("Loading the data completed in %s seconds", loading_end_time - loading_start_time)
4447
self.original_user_data = raw_data.copy()
4548
data_transformer = Transformations(raw_data, spec)
4649
data = data_transformer.run()
50+
transformation_end_time = time.time()
51+
logger.info("Transformations are completed in %s seconds", transformation_end_time - loading_end_time)
4752
try:
4853
spec.freq = utils.get_frequency_of_datetime(data, spec)
4954
except TypeError as e:

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def _generate_report(self):
378378
)
379379

380380
def _custom_predict_prophet(self, data):
381+
data[PROPHET_INTERNAL_DATE_COL] = pd.to_datetime(data[PROPHET_INTERNAL_DATE_COL], unit='s')
381382
return self.models[self.target_columns.index(self.series_id)].predict(
382383
data.reset_index()
383384
)["yhat"]

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ads.opctl.operator.common.utils import _load_yaml_from_uri
1313
from ads.opctl.operator.common.operator_config import OperatorConfig
1414

15-
from .const import SupportedMetrics
15+
from .const import SupportedMetrics, SpeedAccuracyMode
1616
from .const import SupportedModels
1717

1818
@dataclass(repr=True)
@@ -88,6 +88,7 @@ class ForecastOperatorSpec(DataClassSerializable):
8888
generate_report: bool = None
8989
generate_metrics: bool = None
9090
generate_explanations: bool = None
91+
explanations_accuracy_mode: str = None
9192
horizon: int = None
9293
freq: str = None
9394
model: str = None
@@ -119,6 +120,7 @@ def __post_init__(self):
119120
if self.generate_explanations is not None
120121
else False
121122
)
123+
self.explanations_accuracy_mode = self.explanations_accuracy_mode or SpeedAccuracyMode.FAST_APPROXIMATE
122124
self.report_theme = self.report_theme or "light"
123125
self.metrics_filename = self.metrics_filename or "metrics.csv"
124126
self.test_metrics_filename = self.test_metrics_filename or "test_metrics.csv"

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@ spec:
234234
meta:
235235
description: "Explainability, both local and global, can be disabled using this flag. Defaults to false."
236236

237+
explanations_accuracy_mode:
238+
type: string
239+
required: false
240+
default: FAST_APPROXIMATE
241+
allowed:
242+
- HIGH_ACCURACY
243+
- BALANCED
244+
- FAST_APPROXIMATE
245+
237246
generate_report:
238247
type: boolean
239248
required: false

0 commit comments

Comments
 (0)