Skip to content

Commit e7255fe

Browse files
committed
small changes
1 parent 48bc34e commit e7255fe

File tree

5 files changed

+29
-15
lines changed

5 files changed

+29
-15
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,9 @@ def _build_model(self) -> pd.DataFrame:
127127
outputs[target] = forecast
128128

129129
params = vars(model).copy()
130-
params.pop("arima_res_")
131-
params.pop("endog_index_")
130+
for param in ['arima_res', 'endog_index_']:
131+
if param in params:
132+
params.pop(param)
132133
self.model_parameters[target] = {
133134
"framework": SupportedModels.Arima,
134135
**params,
@@ -195,7 +196,7 @@ def _generate_report(self):
195196
global_explanation_df = pd.DataFrame(self.global_explanation)
196197

197198
self.formatted_global_explanation = (
198-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
199+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
199200
)
200201

201202
# Create a markdown section for the global explainability

ads/opctl/operator/lowcode/forecast/model/autots.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ def _build_model(self) -> pd.DataFrame:
218218
"validation_indexes",
219219
"best_model",
220220
]:
221-
params.pop(param)
221+
if param in params:
222+
params.pop(param)
222223
self.model_parameters[cat_target] = {
223224
"framework": SupportedModels.AutoTS,
224225
**params,

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
5656

5757
# these fields are populated in the _build_model() method
5858
self.models = None
59-
# Relevant only for neuralprophet
60-
self.loaded_trainers = None
61-
self.trainers = None
6259

6360
# "outputs" is a list of outputs generated by the models. These should only be generated when the framework requires the original output for plotting
6461
self.outputs = None
@@ -559,13 +556,9 @@ def _save_report(
559556
output_dir=output_dir,
560557
storage_options=storage_options,
561558
)
562-
if self.trainers is not None:
563-
utils.write_pkl(
564-
obj=self.trainers,
565-
filename="trainer.pkl",
566-
output_dir=output_dir,
567-
storage_options=storage_options,
568-
)
559+
560+
self._save_model_specific_files(output_dir, storage_options)
561+
569562
logger.info(
570563
f"The outputs have been successfully "
571564
f"generated and placed into the directory: {output_dir}."
@@ -599,6 +592,13 @@ def _generate_train_metrics(self) -> pd.DataFrame:
599592
"""
600593
raise NotImplementedError
601594

595+
def _save_model_specific_files(self, output_dir, storage_options):
596+
"""
597+
The method that needs to be implemented on the particular model level
598+
"""
599+
pass
600+
601+
602602
@runtime_dependency(
603603
module="shap",
604604
err_msg=(

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
7070
super().__init__(config=config, datasets=datasets)
7171
self.train_metrics = True
7272
self.forecast_col_name = "yhat1"
73+
self.loaded_trainers = None
74+
self.trainers = None
7375

7476
def _build_model(self) -> pd.DataFrame:
7577
from neuralprophet import NeuralProphet
@@ -319,6 +321,15 @@ def objective(trial):
319321

320322
return output_col
321323

324+
def _save_model_specific_files(self, output_dir, storage_options):
325+
if self.spec.generate_model_pickle and self.trainers is not None:
326+
utils.write_pkl(
327+
obj=self.trainers,
328+
filename="trainer.pkl",
329+
output_dir=output_dir,
330+
storage_options=storage_options,
331+
)
332+
322333
def _generate_report(self):
323334
import datapane as dp
324335

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def objective(trial):
213213

214214
params = vars(model).copy()
215215
for param in ["history", "history_dates", "stan_fit"]:
216-
params.pop(param)
216+
if param in params:
217+
params.pop(param)
217218
self.model_parameters[target] = {
218219
"framework": SupportedModels.Prophet,
219220
**params,

0 commit comments

Comments
 (0)