Commit 99d4f5e

merging latest changes
2 parents: 995ac23 + 4422ce8

File tree

9 files changed (+577, -601 lines)


ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
@@ -81,3 +81,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
 DEFAULT_TRIALS = 10
 SUMMARY_METRICS_HORIZON_LIMIT = 10
 PROPHET_INTERNAL_DATE_COL = "ds"
+RENDER_LIMIT = 5000
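
Note: this hunk only introduces the RENDER_LIMIT constant; its consumer is not shown here. A plausible reading is that it caps how many rows get passed to report plotting. A minimal sketch under that assumption, where downsample_for_render is a purely hypothetical helper and not part of the operator:

import pandas as pd
from ads.opctl.operator.lowcode.forecast.const import RENDER_LIMIT

def downsample_for_render(df: pd.DataFrame, limit: int = RENDER_LIMIT) -> pd.DataFrame:
    # Hypothetical helper: keep at most roughly `limit` evenly spaced rows for plotting.
    if len(df) <= limit:
        return df
    step = -(-len(df) // limit)  # ceiling division so the result never exceeds `limit`
    return df.iloc[::step]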

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 72 additions & 61 deletions
@@ -7,6 +7,7 @@
 import pandas as pd
 import numpy as np
 import pmdarima as pm
+from joblib import Parallel, delayed
 
 from ads.opctl import logger
 
@@ -15,7 +16,7 @@
 from ..operator_config import ForecastOperatorConfig
 import traceback
 from .forecast_datasets import ForecastDatasets, ForecastOutput
-from ..const import ForecastOutputColumns, SupportedModels
+from ..const import ForecastOutputColumns
 
 
 class ArimaOperatorModel(ForecastOperatorBaseModel):
@@ -29,32 +30,31 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
         self.formatted_global_explanation = None
         self.formatted_local_explanation = None
 
-    def _build_model(self) -> pd.DataFrame:
-        full_data_dict = self.datasets.full_data_dict
-
-        # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
-        if self.spec.confidence_interval_width is None:
-            self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
-                "alpha", 0.05
-            )
-        model_kwargs = self.spec.model_kwargs
-        model_kwargs["alpha"] = 1 - self.spec.confidence_interval_width
-        if "error_action" not in model_kwargs.keys():
-            model_kwargs["error_action"] = "ignore"
+    def _train_model(self, i, target, df):
+        """Trains the ARIMA model for a given target.
 
-        models = []
-        self.datasets.datetime_col = self.spec.datetime_column.name
-        self.forecast_output = ForecastOutput(
-            confidence_interval_width=self.spec.confidence_interval_width
-        )
+        Parameters
+        ----------
+        i: int
+            The index of the target
+        target: str
+            The name of the target
+        df: pd.DataFrame
+            The dataframe containing the target data
+        """
+        try:
+            # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
+            if self.spec.confidence_interval_width is None:
+                self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
+                    "alpha", 0.05
+                )
+            model_kwargs = self.spec.model_kwargs
+            model_kwargs["alpha"] = 1 - self.spec.confidence_interval_width
+            if "error_action" not in model_kwargs.keys():
+                model_kwargs["error_action"] = "ignore"
 
-        outputs = dict()
-        outputs_legacy = []
-        fitted_values = dict()
-        actual_values = dict()
-        dt_columns = dict()
+            # models = []
 
-        for i, (target, df) in enumerate(full_data_dict.items()):
             # format the dataframe for this target. Dropping NA on target[df] will remove all future data
             le, df_encoded = utils._label_encode_dataframe(
                 df, no_encode={self.spec.datetime_column.name, target}
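
For context on the alpha conversion in the hunk above: pmdarima expresses interval width as a significance level, so a confidence_interval_width of 0.80 maps to alpha = 0.20. A standalone sketch of the same resolution logic, outside the operator classes (the names spec_width and model_kwargs here are illustrative, not the operator's API):

model_kwargs = {}       # user-supplied kwargs from the spec
spec_width = None       # confidence_interval_width from the spec, possibly unset

# Mirror the hunk above: derive the width from alpha if unset, then derive alpha from the width.
if spec_width is None:
    spec_width = 1 - model_kwargs.get("alpha", 0.05)
model_kwargs["alpha"] = 1 - spec_width
model_kwargs.setdefault("error_action", "ignore")

# Round-trips back to the default 95% interval (alpha = 0.05, up to float rounding).
assert abs(model_kwargs["alpha"] - 0.05) < 1e-12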
@@ -72,34 +72,28 @@ def _build_model(self) -> pd.DataFrame:
                 target,
                 self.spec.datetime_column.name,
             }
-            logger.debug(
-                f"Additional Regressors Detected {list(additional_regressors)}"
-            )
+            logger.debug(f"Additional Regressors Detected {list(additional_regressors)}")
 
             # Split data into X and y for arima tune method
             y = data_i[target]
             X_in = None
             if len(additional_regressors):
                 X_in = data_i.drop(target, axis=1)
 
-            model = self.loaded_models[i] if self.loaded_models is not None else None
-            if model is None:
-                # Build and fit model
-                model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)
+            # Build and fit model
+            model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)
 
-            fitted_values[target] = model.predict_in_sample(X=X_in)
-            actual_values[target] = y
-            actual_values[target].index = pd.to_datetime(y.index)
+            self.fitted_values[target] = model.predict_in_sample(X=X_in)
+            self.actual_values[target] = y
+            self.actual_values[target].index = pd.to_datetime(y.index)
 
             # Build future dataframe
             start_date = y.index.values[-1]
             n_periods = self.spec.horizon
             if len(additional_regressors):
                 X = df_clean[df_clean[target].isnull()].drop(target, axis=1)
             else:
-                X = pd.date_range(
-                    start=start_date, periods=n_periods, freq=self.spec.freq
-                )
+                X = pd.date_range(start=start_date, periods=n_periods, freq=self.spec.freq)
 
             # Predict and format forecast
             yhat, conf_int = model.predict(
@@ -110,7 +104,7 @@ def _build_model(self) -> pd.DataFrame:
             )
             yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])
 
-            dt_columns[target] = df_encoded[self.spec.datetime_column.name]
+            self.dt_columns[target] = df_encoded[self.spec.datetime_column.name]
             conf_int_clean = pd.DataFrame(
                 conf_int, index=yhat.index, columns=["yhat_lower", "yhat_upper"]
             )
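
The fit and predict calls these hunks rely on follow pmdarima's standard API: predict_in_sample() gives fitted values and predict(..., return_conf_int=True) returns the point forecast together with the interval bounds. A minimal, self-contained sketch of that pattern on synthetic data, not the operator's code path:

import numpy as np
import pandas as pd
import pmdarima as pm

# Synthetic monthly series with a trend.
idx = pd.date_range("2020-01-01", periods=48, freq="MS")
y = pd.Series(np.arange(48) + np.random.normal(0, 1, 48), index=idx)

model = pm.auto_arima(y=y, error_action="ignore", suppress_warnings=True)

fitted = model.predict_in_sample()  # in-sample fitted values
yhat, conf_int = model.predict(n_periods=12, return_conf_int=True, alpha=0.05)

# conf_int is an (n_periods, 2) array of lower/upper bounds.
forecast = pd.DataFrame(
    {
        "yhat": np.asarray(yhat),
        "yhat_lower": conf_int[:, 0],
        "yhat_upper": conf_int[:, 1],
    },
    index=pd.date_range(idx[-1] + idx.freq, periods=12, freq="MS"),
)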
@@ -119,25 +113,42 @@ def _build_model(self) -> pd.DataFrame:
             logger.debug(forecast[["yhat", "yhat_lower", "yhat_upper"]].tail())
 
             # Collect all outputs
-            if self.loaded_models is None:
-                models.append(model)
-            outputs_legacy.append(
+            # models.append(model)
+            self.outputs_legacy.append(
                 forecast.reset_index().rename(columns={"index": "ds"})
             )
-            outputs[target] = forecast
-
-            params = vars(model).copy()
-            for param in ['arima_res', 'endog_index_']:
-                if param in params:
-                    params.pop(param)
-            self.model_parameters[target] = {
-                "framework": SupportedModels.Arima,
-                **params,
-            }
+            self.outputs[target] = forecast
+
+            self.models_dict[target] = model
 
-        self.models = self.loaded_models if self.loaded_models is not None else models
+            logger.debug("===========Done===========")
+        except Exception as e:
+            self.errors_dict[target] = {"model_name": self.spec.model, "error": str(e)}
+
+    def _build_model(self) -> pd.DataFrame:
+        full_data_dict = self.datasets.full_data_dict
+
+        self.datasets.datetime_col = self.spec.datetime_column.name
+        self.forecast_output = ForecastOutput(
+            confidence_interval_width=self.spec.confidence_interval_width
+        )
+
+        self.outputs = dict()
+        self.outputs_legacy = []
+        self.fitted_values = dict()
+        self.actual_values = dict()
+        self.dt_columns = dict()
+        self.models_dict = dict()
+        self.errors_dict = dict()
+
+        Parallel(n_jobs=-1, require="sharedmem")(
+            delayed(ArimaOperatorModel._train_model)(self, i, target, df)
+            for self, (i, (target, df)) in zip(
+                [self] * len(full_data_dict), enumerate(full_data_dict.items())
+            )
+        )
 
-        logger.debug("===========Done===========")
+        self.models = [self.models_dict[target] for target in self.target_columns]
 
         # Merge the outputs from each model into 1 df with all outputs by target and category
         col = self.original_target_column
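
The new _build_model above dispatches one _train_model call per target through joblib. With require="sharedmem", joblib falls back to a thread-based backend, so every worker mutates the same dicts on self rather than pickled copies; the zip over [self] * len(full_data_dict) is just a way to pass the bound instance into each delayed call. A simplified, standalone analogue of that pattern (the Trainer class and its averaging "model" are illustrative, not the operator's code):

from joblib import Parallel, delayed

class Trainer:
    def __init__(self, data):
        self.data = data          # {target_name: raw_series}
        self.models = {}          # shared result dict, one entry per target
        self.errors = {}          # per-target failures instead of aborting the whole run

    def _train_one(self, i, target, series):
        try:
            self.models[target] = sum(series) / len(series)  # stand-in for pm.auto_arima
        except Exception as e:
            self.errors[target] = str(e)

    def build(self):
        # require="sharedmem" forces a threading backend, so self.models/self.errors are shared.
        Parallel(n_jobs=-1, require="sharedmem")(
            delayed(self._train_one)(i, target, series)
            for i, (target, series) in enumerate(self.data.items())
        )
        return self.models

trainer = Trainer({"sales_A": [1, 2, 3], "sales_B": [4, 5, 6]})
print(trainer.build())  # {'sales_A': 2.0, 'sales_B': 5.0}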
@@ -146,15 +157,15 @@ def _build_model(self) -> pd.DataFrame:
         yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
         for cat in self.categories:
             output_i = pd.DataFrame()
-            output_i["Date"] = dt_columns[f"{col}_{cat}"]
+            output_i["Date"] = self.dt_columns[f"{col}_{cat}"]
             output_i["Series"] = cat
             output_i = output_i.set_index("Date")
 
-            output_i["input_value"] = actual_values[f"{col}_{cat}"]
-            output_i["fitted_value"] = fitted_values[f"{col}_{cat}"]
-            output_i["forecast_value"] = outputs[f"{col}_{cat}"]["yhat"]
-            output_i[yhat_upper_name] = outputs[f"{col}_{cat}"]["yhat_upper"]
-            output_i[yhat_lower_name] = outputs[f"{col}_{cat}"]["yhat_lower"]
+            output_i["input_value"] = self.actual_values[f"{col}_{cat}"]
+            output_i["fitted_value"] = self.fitted_values[f"{col}_{cat}"]
+            output_i["forecast_value"] = self.outputs[f"{col}_{cat}"]["yhat"]
+            output_i[yhat_upper_name] = self.outputs[f"{col}_{cat}"]["yhat_upper"]
+            output_i[yhat_lower_name] = self.outputs[f"{col}_{cat}"]["yhat_lower"]
 
             output_i = output_i.reset_index(drop=False)
             output_col = pd.concat([output_col, output_i])
@@ -196,7 +207,7 @@ def _generate_report(self):
         global_explanation_df = pd.DataFrame(self.global_explanation)
 
         self.formatted_global_explanation = (
-            global_explanation_df / global_explanation_df.sum(axis=0) * 100
+            global_explanation_df / global_explanation_df.sum(axis=0) * 100
         )
 
         # Create a markdown section for the global explainability
@@ -264,7 +275,7 @@ def _custom_predict_arima(self, data):
 
         """
        date_col = self.spec.datetime_column.name
-        data[date_col] = pd.to_datetime(data[date_col], unit='s')
+        data[date_col] = pd.to_datetime(data[date_col], unit="s")
         data = data.set_index(date_col)
         # Get the index of the current series id
         series_index = self.target_columns.index(self.series_id)
