Skip to content

Commit 7f5d5ab

Browse files
authored
Subsampling data for faster plot rendering (#517)
2 parents b7d54db + 0174aeb commit 7f5d5ab

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
8181
DEFAULT_TRIALS = 10
8282
SUMMARY_METRICS_HORIZON_LIMIT = 10
8383
PROPHET_INTERNAL_DATE_COL = "ds"
84+
RENDER_LIMIT = 5000

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,7 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
598598
data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(lambda x: x.timestamp())
599599
kernel_explnr = PermutationExplainer(
600600
model=explain_predict_fn,
601-
masker=data_trimmed,
602-
keep_index=False
603-
if self.spec.model == SupportedModels.AutoMLX
604-
else True,
601+
masker=data_trimmed
605602
)
606603

607604
kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed)

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ads.dataset.label_encoder import DataFrameLabelEncoder
2929
from ads.opctl import logger
3030

31-
from .const import SupportedMetrics, SupportedModels
31+
from .const import SupportedMetrics, SupportedModels, RENDER_LIMIT
3232
from .errors import ForecastInputDataError, ForecastSchemaYamlError
3333
from .operator_config import ForecastOperatorSpec, ForecastOperatorConfig
3434

@@ -417,6 +417,23 @@ def get_forecast_plots(
417417
def plot_forecast_plotly(idx, col):
418418
fig = go.Figure()
419419
forecast_i = forecast_output.get_target_category(col)
420+
actual_length = len(forecast_i)
421+
if actual_length > RENDER_LIMIT:
422+
forecast_i = forecast_i.tail(RENDER_LIMIT)
423+
text = f"<i>To improve rendering speed, subsampled the data from {actual_length}" \
424+
f" rows to {RENDER_LIMIT} rows for this plot.</i>"
425+
fig.update_layout(
426+
annotations=[
427+
go.layout.Annotation(
428+
x=0.01,
429+
y=1.1,
430+
xref="paper",
431+
yref="paper",
432+
text=text,
433+
showarrow=False
434+
)
435+
]
436+
)
420437
upper_bound = forecast_output.upper_bound_name
421438
lower_bound = forecast_output.lower_bound_name
422439
if upper_bound is not None and lower_bound is not None:

0 commit comments

Comments
 (0)