Skip to content

Commit df3d876

Browse files
committed
Retraining prophet and arima models
1 parent 94a09c0 commit df3d876

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def _train_model(self, i, s_id, df, model_kwargs):
8585
X_pred = self.get_horizon(data).drop(target, axis=1)
8686

8787
if self.loaded_models is not None and s_id in self.loaded_models:
88-
model = self.loaded_models[s_id]
88+
model = self.loaded_models[s_id]["model"]
89+
order = model.order
90+
seasonal_order = model.seasonal_order
91+
model = pm.ARIMA(order=order, seasonal_order=seasonal_order)
92+
model.fit(y=y, X=X_in)
8993
else:
9094
# Build and fit model
9195
model = pm.auto_arima(y=y, X=X_in, **model_kwargs)

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import matplotlib as mpl
1010
import numpy as np
1111
import optuna
12+
import inspect
1213
import pandas as pd
1314
from joblib import Parallel, delayed
1415

@@ -39,6 +40,22 @@ def _add_unit(num, unit):
3940
return f"{num} {unit}"
4041

4142

43+
def _extract_parameter(model):
44+
"""
45+
extract Prophet initialization parameters
46+
"""
47+
from prophet import Prophet
48+
sig = inspect.signature(Prophet.__init__)
49+
param_names = list(sig.parameters.keys())
50+
params = {}
51+
for name in param_names:
52+
if hasattr(model, name):
53+
value = getattr(model, name)
54+
if isinstance(value, (int, float, str, bool, type(None), dict, list)):
55+
params[name] = value
56+
return params
57+
58+
4259
def _fit_model(data, params, additional_regressors):
4360
from prophet import Prophet
4461

@@ -96,16 +113,17 @@ def _train_model(self, i, series_id, df, model_kwargs):
96113
data = self.preprocess(df, series_id)
97114
data_i = self.drop_horizon(data)
98115
if self.loaded_models is not None and series_id in self.loaded_models:
99-
model = self.loaded_models[series_id]
116+
previous_model = self.loaded_models[series_id]["model"]
117+
model_kwargs.update(_extract_parameter(previous_model))
100118
else:
101119
if self.perform_tuning:
102120
model_kwargs = self.run_tuning(data_i, model_kwargs)
103121

104-
model = _fit_model(
105-
data=data,
106-
params=model_kwargs,
107-
additional_regressors=self.additional_regressors,
108-
)
122+
model = _fit_model(
123+
data=data,
124+
params=model_kwargs,
125+
additional_regressors=self.additional_regressors,
126+
)
109127

110128
# Get future df for prediction
111129
future = data.drop("y", axis=1)

0 commit comments

Comments
 (0)