Skip to content

Commit 691b486

Browse files
committed
small changes
1 parent e7255fe commit 691b486

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ def generate_report(self):
7979

8080
# load models if given
8181
if self.spec.previous_output_dir is not None:
82-
try:
83-
self.loaded_models = utils.load_pkl(self.spec.previous_output_dir + "/model.pkl")
84-
except:
85-
logger.info("model.pkl is not present")
82+
self._load_model()
8683

8784
start_time = time.time()
8885
result_df = self._build_model()
@@ -592,6 +589,12 @@ def _generate_train_metrics(self) -> pd.DataFrame:
592589
"""
593590
raise NotImplementedError
594591

592+
def _load_model(self):
593+
try:
594+
self.loaded_models = utils.load_pkl(self.spec.previous_output_dir + "/model.pkl")
595+
except:
596+
logger.info("model.pkl is not present")
597+
595598
def _save_model_specific_files(self, output_dir, storage_options):
596599
"""
597600
The method that needs to be implemented on the particular model level

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,21 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
7373
self.loaded_trainers = None
7474
self.trainers = None
7575

76+
def _load_model(self):
77+
try:
78+
self.loaded_models = utils.load_pkl(self.spec.previous_output_dir + "/model.pkl")
79+
self.loaded_trainers = utils.load_pkl(self.spec.previous_output_dir + "/trainer.pkl")
80+
except:
81+
logger.info("model.pkl/trainer.pkl is not present")
82+
83+
7684
def _build_model(self) -> pd.DataFrame:
7785
from neuralprophet import NeuralProphet
7886

7987
full_data_dict = self.datasets.full_data_dict
8088
models = []
8189
trainers = []
8290

83-
if self.loaded_models is not None:
84-
try:
85-
self.loaded_trainers = utils.load_pkl(self.spec.previous_output_dir + "/trainer.pkl")
86-
except:
87-
logger.info("trainer.pkl is not present")
88-
8991
outputs = dict()
9092
outputs_legacy = []
9193

0 commit comments

Comments
 (0)