
Commit 4523a0d

Add post-processing step for forecast clipping
1 parent ef7c14d commit 4523a0d

10 files changed: +119 −15 lines changed
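Taken together, the changes thread a `postprocessing` spec through each model's `ForecastOutput` and clip every series' forecast values to the configured bounds before the forecast.csv output is produced. A minimal sketch of the intended effect; the array and thresholds below are illustrative only, not taken from the commit:

import numpy as np

# Hypothetical forecast values and clipping bounds, for illustration only.
forecast_val = np.array([38.0, 41.0, 45.0])
set_min_forecast, set_max_forecast = 40, 42  # mirrors the new spec fields

# np.clip bounds every value to [set_min_forecast, set_max_forecast].
clipped = np.clip(forecast_val, set_min_forecast, set_max_forecast)
print(clipped)  # [40. 41. 42.]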

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.spec.datetime_column.name,
+            postprocessing=self.spec.postprocessing,
         )
 
         Parallel(n_jobs=-1, require="sharedmem")(

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.spec.datetime_column.name,
+            postprocessing=self.spec.postprocessing,
         )
 
         # Clean up kwargs for pass through

ads/opctl/operator/lowcode/forecast/model/autots.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.spec.datetime_column.name,
+            postprocessing=self.spec.postprocessing,
         )
         try:
             model = self.loaded_models if self.loaded_models is not None else None

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 47 additions & 15 deletions
@@ -5,6 +5,7 @@
 
 from typing import Dict, List
 
+import numpy as np
 import pandas as pd
 
 from ads.opctl import logger
@@ -18,13 +19,15 @@
     get_frequency_of_datetime,
 )
 
-from ..const import ForecastOutputColumns, SupportedModels, TROUBLESHOOTING_GUIDE
-from ..operator_config import ForecastOperatorConfig
+from ..const import TROUBLESHOOTING_GUIDE, ForecastOutputColumns, SupportedModels
+from ..operator_config import ForecastOperatorConfig, PostprocessingSteps
 
 
 class HistoricalData(AbstractData):
     def __init__(self, spec, historical_data=None, subset=None):
-        super().__init__(spec=spec, name="historical_data", data=historical_data, subset=subset)
+        super().__init__(
+            spec=spec, name="historical_data", data=historical_data, subset=subset
+        )
         self.subset = subset
 
     def _ingest_data(self, spec):
@@ -49,15 +52,19 @@ def _verify_dt_col(self, spec):
                 f"{SupportedModels.AutoMLX} requires data with a frequency of at least one hour. Please try using a different model,"
                 " or select the 'auto' option."
             )
-            raise InvalidParameterError(f"{message}"
-                                        f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps.")
+            raise InvalidParameterError(
+                f"{message}"
+                f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps."
+            )
 
 
 class AdditionalData(AbstractData):
     def __init__(self, spec, historical_data, additional_data=None, subset=None):
         self.subset = subset
         if additional_data is not None:
-            super().__init__(spec=spec, name="additional_data", data=additional_data, subset=subset)
+            super().__init__(
+                spec=spec, name="additional_data", data=additional_data, subset=subset
+            )
             self.additional_regressors = list(self.data.columns)
         elif spec.additional_data is not None:
             super().__init__(spec=spec, name="additional_data", subset=subset)
@@ -70,7 +77,7 @@ def __init__(self, spec, historical_data, additional_data=None, subset=None):
                 )
             elif historical_data.get_max_time() != add_dates[-(spec.horizon + 1)]:
                 raise DataMismatchError(
-                    f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additonal data horizon starts after {add_dates[-(spec.horizon+1)]}. These should be the same date."
+                    f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additonal data horizon starts after {add_dates[-(spec.horizon + 1)]}. These should be the same date."
                     f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps."
                 )
             else:
@@ -150,7 +157,9 @@ def __init__(
         self._datetime_column_name = config.spec.datetime_column.name
         self._target_col = config.spec.target_column
         if historical_data is not None:
-            self.historical_data = HistoricalData(config.spec, historical_data, subset=subset)
+            self.historical_data = HistoricalData(
+                config.spec, historical_data, subset=subset
+            )
             self.additional_data = AdditionalData(
                 config.spec, self.historical_data, additional_data, subset=subset
             )
@@ -276,6 +285,7 @@ def __init__(
         horizon: int,
         target_column: str,
         dt_column: str,
+        postprocessing: PostprocessingSteps,
     ):
         """Forecast Output contains all the details required to generate the forecast.csv output file.
@@ -285,12 +295,14 @@
         horizon: int length of horizon
         target_column: str the name of the original target column
         dt_column: the name of the original datetime column
+        postprocessing: postprocessing steps to be executed
         """
         self.series_id_map = {}
         self._set_ci_column_names(confidence_interval_width)
         self.horizon = horizon
         self.target_column_name = target_column
         self.dt_column_name = dt_column
+        self.postprocessing = postprocessing
 
     def add_series_id(
         self,
@@ -337,6 +349,12 @@ def populate_series_output(
         --------
         None
         """
+        min_threshold, max_threshold = (
+            self.postprocessing.set_min_forecast,
+            self.postprocessing.set_max_forecast,
+        )
+        if min_threshold is not None or max_threshold is not None:
+            np.clip(forecast_val, min_threshold, max_threshold, out=forecast_val)
         try:
            output_i = self.series_id_map[series_id]
        except KeyError as e:
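A note on the np.clip call added above: a None bound means "no bound on that side", so configuring only one of set_min_forecast / set_max_forecast clips in one direction only, and out=forecast_val writes the result back in place (here assumed to be a NumPy array). A self-contained sketch with illustrative numbers only:

import numpy as np

forecast_val = np.array([-3.0, 5.0, 12.0])  # illustrative values

# Only a lower bound: negatives are raised to 0, the upper side is untouched.
np.clip(forecast_val, 0, None, out=forecast_val)
print(forecast_val)  # [ 0.  5. 12.]

# Only an upper bound: values above 10 are lowered, nothing else changes.
np.clip(forecast_val, None, 10, out=forecast_val)
print(forecast_val)  # [ 0.  5. 10.]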
@@ -422,9 +440,9 @@ def _set_ci_column_names(self, confidence_interval_width):
 
     def _check_forecast_format(self, forecast):
         assert isinstance(forecast, pd.DataFrame)
-        assert (
-            len(forecast.columns) == 7
-        ), f"Expected just 7 columns, but got: {forecast.columns}"
+        assert len(forecast.columns) == 7, (
+            f"Expected just 7 columns, but got: {forecast.columns}"
+        )
         assert ForecastOutputColumns.DATE in forecast.columns
         assert ForecastOutputColumns.SERIES in forecast.columns
         assert ForecastOutputColumns.INPUT_VALUE in forecast.columns
@@ -506,16 +524,30 @@ def set_errors_dict(self, errors_dict: Dict):
     def get_errors_dict(self):
         return getattr(self, "errors_dict", None)
 
-    def merge(self, other: 'ForecastResults'):
+    def merge(self, other: "ForecastResults"):
         """Merge another ForecastResults object into this one."""
         # Merge DataFrames if they exist, else just set
         for attr in [
-            'forecast', 'metrics', 'test_metrics', 'local_explanations', 'global_explanations', 'model_parameters', 'models', 'errors_dict']:
+            "forecast",
+            "metrics",
+            "test_metrics",
+            "local_explanations",
+            "global_explanations",
+            "model_parameters",
+            "models",
+            "errors_dict",
+        ]:
             val_self = getattr(self, attr, None)
             val_other = getattr(other, attr, None)
             if val_self is not None and val_other is not None:
-                if isinstance(val_self, pd.DataFrame) and isinstance(val_other, pd.DataFrame):
-                    setattr(self, attr, pd.concat([val_self, val_other], ignore_index=True, axis=0))
+                if isinstance(val_self, pd.DataFrame) and isinstance(
+                    val_other, pd.DataFrame
+                ):
+                    setattr(
+                        self,
+                        attr,
+                        pd.concat([val_self, val_other], ignore_index=True, axis=0),
+                    )
                 elif isinstance(val_self, dict) and isinstance(val_other, dict):
                     val_self.update(val_other)
                     setattr(self, attr, val_self)

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 1 addition & 0 deletions
@@ -182,6 +182,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.date_col,
+            postprocessing=self.spec.postprocessing,
         )
         self._train_model(data_train, data_test, model_kwargs)
         return self.forecast_output.get_forecast_long()

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.spec.datetime_column.name,
+            postprocessing=self.spec.postprocessing,
         )
 
         for i, (s_id, df) in enumerate(full_data_dict.items()):

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def _build_model(self) -> pd.DataFrame:
             horizon=self.spec.horizon,
             target_column=self.original_target_column,
             dt_column=self.spec.datetime_column.name,
+            postprocessing=self.spec.postprocessing,
         )
 
         Parallel(n_jobs=-1, require="sharedmem")(

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 14 additions & 0 deletions
@@ -76,6 +76,14 @@ class PreprocessingSteps(DataClassSerializable):
     outlier_treatment: bool = True
 
 
+@dataclass(repr=True)
+class PostprocessingSteps(DataClassSerializable):
+    """Class representing postprocessing steps for operator."""
+
+    set_min_forecast: int = None
+    set_max_forecast: int = None
+
+
 @dataclass(repr=True)
 class DataPreprocessor(DataClassSerializable):
     """Class representing operator specification preprocessing details."""
@@ -110,6 +118,7 @@ class ForecastOperatorSpec(DataClassSerializable):
     local_explanation_filename: str = None
     target_column: str = None
     preprocessing: DataPreprocessor = field(default_factory=DataPreprocessor)
+    postprocessing: PostprocessingSteps = field(default_factory=PostprocessingSteps)
     datetime_column: DateTimeColumn = field(default_factory=DateTimeColumn)
     target_category_columns: List[str] = field(default_factory=list)
     generate_report: bool = None
@@ -146,6 +155,11 @@ def __post_init__(self):
             if self.preprocessing is not None
             else DataPreprocessor(enabled=True)
         )
+        self.postprocessing = (
+            self.postprocessing
+            if self.postprocessing is not None
+            else PostprocessingSteps()
+        )
         # For Report Generation. When user doesn't specify defaults to True
         self.generate_report = (
             self.generate_report if self.generate_report is not None else True
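For context on the fallback above: both thresholds default to None, and __post_init__ substitutes an empty PostprocessingSteps when the YAML omits the section, so forecasts are left untouched unless at least one bound is configured. A small sketch using a local stand-in dataclass (defined here for illustration, not imported from the operator):

from dataclasses import dataclass
from typing import Optional

# Stand-in mirroring the new PostprocessingSteps fields (illustration only).
@dataclass(repr=True)
class PostprocessingSteps:
    set_min_forecast: Optional[int] = None
    set_max_forecast: Optional[int] = None

# Spec omitted entirely -> both bounds stay None -> no clipping is applied.
default_steps = PostprocessingSteps()
assert default_steps.set_min_forecast is None and default_steps.set_max_forecast is None

# Spec provides only a floor -> only the lower bound is enforced.
floor_only = PostprocessingSteps(set_min_forecast=0)
print(floor_only)  # PostprocessingSteps(set_min_forecast=0, set_max_forecast=None)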

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 15 additions & 0 deletions
@@ -329,6 +329,21 @@ spec:
       required: false
       default: false
 
+    postprocessing:
+      type: dict
+      required: false
+      schema:
+        set_min_forecast:
+          type: integer
+          required: false
+          meta:
+            description: "This can be used to define the minimum forecast in the output."
+        set_max_forecast:
+          type: integer
+          required: false
+          meta:
+            description: "This can be used to define the maximum forecast in the output."
+
     generate_explanations:
       type: boolean
       required: false

tests/operators/forecast/test_datasets.py

Lines changed: 37 additions & 0 deletions
@@ -413,5 +413,42 @@ def run_operator(
 # generate_train_metrics = True
 
 
+def test_postprocessing_clipping():
+    """Tests the postprocessing clipping of forecast values."""
+    df = pd.DataFrame(
+        {
+            "Date": pd.to_datetime(pd.date_range("2023-01-01", periods=20, freq="D")),
+            "Y": range(0, 40, 2),
+        }
+    )
+
+    min_clip = 40
+
+    max_clip = 42
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        output_data_path = f"{tmpdirname}/results"
+        yaml_i = deepcopy(TEMPLATE_YAML)
+        yaml_i["spec"]["model"] = "prophet"
+        yaml_i["spec"]["historical_data"].pop("url")
+        yaml_i["spec"]["historical_data"]["data"] = df
+        yaml_i["spec"]["target_column"] = "Y"
+        yaml_i["spec"]["datetime_column"]["name"] = DATETIME_COL
+        yaml_i["spec"]["horizon"] = 5
+        yaml_i["spec"]["output_directory"]["url"] = output_data_path
+        yaml_i["spec"]["postprocessing"] = {
+            "set_min_forecast": min_clip,
+            "set_max_forecast": max_clip,
+        }
+
+        operator_config = ForecastOperatorConfig.from_dict(yaml_i)
+        forecast_operate(operator_config)
+
+        forecast_df = pd.read_csv(f"{output_data_path}/forecast.csv")
+
+        assert forecast_df["forecast_value"].min() >= min_clip
+        assert forecast_df["forecast_value"].max() <= max_clip
+
+
 if __name__ == "__main__":
     pass
