Skip to content

Commit 4afe3c7

Browse files
authored
Model parameters and pickle generation, reloading of model pickle (#512)
2 parents 4422ce8 + 338ac6b commit 4afe3c7

File tree

9 files changed

+567
-349
lines changed

9 files changed

+567
-349
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..operator_config import ForecastOperatorConfig
1717
import traceback
1818
from .forecast_datasets import ForecastDatasets, ForecastOutput
19-
from ..const import ForecastOutputColumns
19+
from ..const import ForecastOutputColumns, SupportedModels
2020

2121

2222
class ArimaOperatorModel(ForecastOperatorBaseModel):
@@ -80,8 +80,10 @@ def _train_model(self, i, target, df):
8080
if len(additional_regressors):
8181
X_in = data_i.drop(target, axis=1)
8282

83-
# Build and fit model
84-
model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)
83+
model = self.loaded_models[target] if self.loaded_models is not None else None
84+
if model is None:
85+
# Build and fit model
86+
model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)
8587

8688
self.fitted_values[target] = model.predict_in_sample(X=X_in)
8789
self.actual_values[target] = y
@@ -119,7 +121,17 @@ def _train_model(self, i, target, df):
119121
)
120122
self.outputs[target] = forecast
121123

122-
self.models_dict[target] = model
124+
if self.loaded_models is None:
125+
self.models[target] = model
126+
127+
params = vars(model).copy()
128+
for param in ['arima_res_', 'endog_index_']:
129+
if param in params:
130+
params.pop(param)
131+
self.model_parameters[target] = {
132+
"framework": SupportedModels.Arima,
133+
**params,
134+
}
123135

124136
logger.debug("===========Done===========")
125137
except Exception as e:
@@ -133,12 +145,12 @@ def _build_model(self) -> pd.DataFrame:
133145
confidence_interval_width=self.spec.confidence_interval_width
134146
)
135147

148+
self.models = dict()
136149
self.outputs = dict()
137150
self.outputs_legacy = []
138151
self.fitted_values = dict()
139152
self.actual_values = dict()
140153
self.dt_columns = dict()
141-
self.models_dict = dict()
142154
self.errors_dict = dict()
143155

144156
Parallel(n_jobs=-1, require="sharedmem")(
@@ -148,13 +160,15 @@ def _build_model(self) -> pd.DataFrame:
148160
)
149161
)
150162

151-
self.models = [self.models_dict[target] for target in self.target_columns]
163+
if self.loaded_models is not None:
164+
self.models = self.loaded_models
152165

153166
# Merge the outputs from each model into 1 df with all outputs by target and category
154167
col = self.original_target_column
155168
output_col = pd.DataFrame()
156169
yhat_upper_name = ForecastOutputColumns.UPPER_BOUND
157170
yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
171+
158172
for cat in self.categories:
159173
output_i = pd.DataFrame()
160174
output_i["Date"] = self.dt_columns[f"{col}_{cat}"]
@@ -183,8 +197,8 @@ def _generate_report(self):
183197

184198
sec5_text = dp.Text(f"## ARIMA Model Parameters")
185199
blocks = [
186-
dp.HTML(m.summary().as_html(), label=self.target_columns[i])
187-
for i, m in enumerate(self.models)
200+
dp.HTML(m.summary().as_html(), label=target)
201+
for i, (target, m) in enumerate(self.models.items())
188202
]
189203
sec5 = dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
190204
all_sections = [sec5_text, sec5]
@@ -196,7 +210,6 @@ def _generate_report(self):
196210
datetime_col_name=self.spec.datetime_column.name,
197211
explain_predict_fn=self._custom_predict_arima,
198212
)
199-
200213
# Create a markdown text block for the global explanation section
201214
global_explanation_text = dp.Text(
202215
f"## Global Explanation of Models \n "
@@ -277,10 +290,8 @@ def _custom_predict_arima(self, data):
277290
date_col = self.spec.datetime_column.name
278291
data[date_col] = pd.to_datetime(data[date_col], unit="s")
279292
data = data.set_index(date_col)
280-
# Get the index of the current series id
281-
series_index = self.target_columns.index(self.series_id)
282293

283294
# Use the ARIMA model to predict the values
284-
predictions = self.models[series_index].predict(X=data, n_periods=len(data))
295+
predictions = self.models[self.series_id].predict(X=data, n_periods=len(data))
285296

286297
return predictions

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ads.common.decorator.runtime_dependency import runtime_dependency
1111
from ads.opctl.operator.lowcode.forecast.const import (
1212
AUTOMLX_METRIC_MAP,
13-
ForecastOutputColumns,
13+
ForecastOutputColumns, SupportedModels,
1414
)
1515
from ads.opctl import logger
1616

@@ -60,7 +60,6 @@ def _build_model(self) -> pd.DataFrame:
6060
models = dict()
6161
outputs = dict()
6262
outputs_legacy = dict()
63-
selected_models = dict()
6463
date_column = self.spec.datetime_column.name
6564
horizon = self.spec.horizon
6665
self.datasets.datetime_col = date_column
@@ -71,19 +70,22 @@ def _build_model(self) -> pd.DataFrame:
7170
self.errors_dict = dict()
7271

7372
# Clean up kwargs for pass through
74-
model_kwargs_cleaned = self.spec.model_kwargs.copy()
75-
model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
76-
"n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
77-
)
78-
model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
79-
self.spec.metric,
80-
model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
81-
)
82-
model_kwargs_cleaned.pop("task", None)
83-
time_budget = model_kwargs_cleaned.pop("time_budget", 0)
84-
model_kwargs_cleaned[
85-
"preprocessing"
86-
] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
73+
model_kwargs_cleaned = None
74+
75+
if self.loaded_models is None:
76+
model_kwargs_cleaned = self.spec.model_kwargs.copy()
77+
model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
78+
"n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
79+
)
80+
model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
81+
self.spec.metric,
82+
model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
83+
)
84+
model_kwargs_cleaned.pop("task", None)
85+
time_budget = model_kwargs_cleaned.pop("time_budget", 0)
86+
model_kwargs_cleaned[
87+
"preprocessing"
88+
] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
8789

8890
for i, (target, df) in enumerate(full_data_dict.items()):
8991
try:
@@ -107,15 +109,18 @@ def _build_model(self) -> pd.DataFrame:
107109
if y_train.index.is_monotonic
108110
else "NOT" + "monotonic."
109111
)
110-
model = automl.Pipeline(
111-
task="forecasting",
112-
**model_kwargs_cleaned,
113-
)
114-
model.fit(
115-
X=y_train.drop(target, axis=1),
116-
y=pd.DataFrame(y_train[target]),
117-
time_budget=time_budget,
118-
)
112+
model = self.loaded_models[target] if self.loaded_models is not None else None
113+
114+
if model is None:
115+
model = automl.Pipeline(
116+
task="forecasting",
117+
**model_kwargs_cleaned,
118+
)
119+
model.fit(
120+
X=y_train.drop(target, axis=1),
121+
y=pd.DataFrame(y_train[target]),
122+
time_budget=time_budget,
123+
)
119124
logger.debug("Selected model: {}".format(model.selected_model_))
120125
logger.debug(
121126
"Selected model params: {}".format(model.selected_model_params_)
@@ -142,12 +147,8 @@ def _build_model(self) -> pd.DataFrame:
142147
)
143148

144149
# Collect Outputs
145-
selected_models[target] = {
146-
"series_id": target,
147-
"selected_model": model.selected_model_,
148-
"model_params": model.selected_model_params_,
149-
}
150-
models[target] = model
150+
if self.loaded_models is None:
151+
models[target] = model
151152
summary_frame = summary_frame.rename_axis("ds").reset_index()
152153
summary_frame = summary_frame.rename(
153154
columns={
@@ -162,6 +163,24 @@ def _build_model(self) -> pd.DataFrame:
162163
summary_frame["yhat_lower"] = np.NAN
163164
outputs[target] = summary_frame
164165
# outputs_legacy[target] = summary_frame
166+
167+
self.model_parameters[target] = {
168+
"framework": SupportedModels.AutoMLX,
169+
"score_metric": model.score_metric,
170+
"random_state": model.random_state,
171+
"model_list": model.model_list,
172+
"n_algos_tuned": model.n_algos_tuned,
173+
"adaptive_sampling": model.adaptive_sampling,
174+
"min_features": model.min_features,
175+
"optimization": model.optimization,
176+
"preprocessing": model.preprocessing,
177+
"search_space": model.search_space,
178+
"time_series_period": model.time_series_period,
179+
"min_class_instances": model.min_class_instances,
180+
"max_tuning_trials": model.max_tuning_trials,
181+
"selected_model": model.selected_model_,
182+
"selected_model_params": model.selected_model_params_,
183+
}
165184
except Exception as e:
166185
self.errors_dict[target] = {"model_name": self.spec.model, "error": str(e)}
167186

@@ -191,7 +210,8 @@ def _build_model(self) -> pd.DataFrame:
191210
# output_col = output_col.reset_index(drop=True)
192211
# outputs_merged = pd.concat([outputs_merged, output_col], axis=1)
193212

194-
self.models = models
213+
self.models = models if self.loaded_models is None else self.loaded_models
214+
195215
return outputs_merged
196216

197217
@runtime_dependency(
@@ -262,7 +282,7 @@ def _generate_report(self):
262282
global_explanation_df = pd.DataFrame(self.global_explanation)
263283

264284
self.formatted_global_explanation = (
265-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
285+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
266286
)
267287

268288
# Create a markdown section for the global explainability
@@ -285,7 +305,7 @@ def _generate_report(self):
285305
dp.DataTable(
286306
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
287307
label=s_id,
288-
)
308+
)
289309
for s_id, local_ex_df in self.local_explanation.items()
290310
]
291311
local_explanation_section = (

0 commit comments

Comments (0)