Skip to content

Commit 94a09c0

Browse files
committed
Retraining automlx models
1 parent 06dd598 commit 94a09c0

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,17 +218,12 @@ def _write_file(local_filename, remote_filename, storage_options, **kwargs):
218218

219219

220220
def load_pkl(filepath):
221-
return _safe_write(fn=_load_pkl, filepath=filepath)
222-
223-
224-
def _load_pkl(filepath):
225221
storage_options = {}
226222
if ObjectStorageDetails.is_oci_path(filepath):
227223
storage_options = default_signer()
228224

229225
with fsspec.open(filepath, "rb", **storage_options) as f:
230226
return cloudpickle.load(f)
231-
return None
232227

233228

234229
def write_pkl(obj, filename, output_dir, storage_options):

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,20 @@ def _build_model(self) -> pd.DataFrame:
142142
)
143143

144144
if self.loaded_models is not None and s_id in self.loaded_models:
145-
model = self.loaded_models[s_id]
146-
else:
147-
model = Pipeline(
148-
task="forecasting",
149-
**model_kwargs,
150-
)
151-
model.fit(
152-
X=data_i.drop(target, axis=1),
153-
y=data_i[[target]],
154-
time_budget=time_budget,
155-
)
145+
model = self.loaded_models[s_id]["model"]
146+
model_kwargs["model_list"] = [model.selected_model_]
147+
model_kwargs["search_space"]={}
148+
model_kwargs["search_space"][model.selected_model_] = model.selected_model_params_
149+
150+
model = Pipeline(
151+
task="forecasting",
152+
**model_kwargs,
153+
)
154+
model.fit(
155+
X=data_i.drop(target, axis=1),
156+
y=data_i[[target]],
157+
time_budget=time_budget,
158+
)
156159
logger.debug(f"Selected model: {model.selected_model_}")
157160
logger.debug(f"Selected model params: {model.selected_model_params_}")
158161
summary_frame = model.forecast(

0 commit comments

Comments
 (0)