 from ads.common.decorator.runtime_dependency import runtime_dependency
 from ads.opctl.operator.lowcode.forecast.const import (
     AUTOMLX_METRIC_MAP,
-    ForecastOutputColumns,
+    ForecastOutputColumns, SupportedModels,
 )
 from ads.opctl import logger

@@ -60,7 +60,6 @@ def _build_model(self) -> pd.DataFrame:
         models = dict()
         outputs = dict()
         outputs_legacy = dict()
-        selected_models = dict()
         date_column = self.spec.datetime_column.name
         horizon = self.spec.horizon
         self.datasets.datetime_col = date_column
@@ -71,19 +70,22 @@ def _build_model(self) -> pd.DataFrame:
         self.errors_dict = dict()

         # Clean up kwargs for pass through
-        model_kwargs_cleaned = self.spec.model_kwargs.copy()
-        model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
-            "n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
-        )
-        model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
-            self.spec.metric,
-            model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
-        )
-        model_kwargs_cleaned.pop("task", None)
-        time_budget = model_kwargs_cleaned.pop("time_budget", 0)
-        model_kwargs_cleaned[
-            "preprocessing"
-        ] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
+        model_kwargs_cleaned = None
+
+        if self.loaded_models is None:
+            model_kwargs_cleaned = self.spec.model_kwargs.copy()
+            model_kwargs_cleaned["n_algos_tuned"] = model_kwargs_cleaned.get(
+                "n_algos_tuned", AUTOMLX_N_ALGOS_TUNED
+            )
+            model_kwargs_cleaned["score_metric"] = AUTOMLX_METRIC_MAP.get(
+                self.spec.metric,
+                model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
+            )
+            model_kwargs_cleaned.pop("task", None)
+            time_budget = model_kwargs_cleaned.pop("time_budget", 0)
+            model_kwargs_cleaned[
+                "preprocessing"
+            ] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)

         for i, (target, df) in enumerate(full_data_dict.items()):
             try:
@@ -107,15 +109,18 @@ def _build_model(self) -> pd.DataFrame:
                     if y_train.index.is_monotonic
                     else "NOT" + "monotonic."
                 )
-                model = automl.Pipeline(
-                    task="forecasting",
-                    **model_kwargs_cleaned,
-                )
-                model.fit(
-                    X=y_train.drop(target, axis=1),
-                    y=pd.DataFrame(y_train[target]),
-                    time_budget=time_budget,
-                )
+                model = self.loaded_models[target] if self.loaded_models is not None else None
+
+                if model is None:
+                    model = automl.Pipeline(
+                        task="forecasting",
+                        **model_kwargs_cleaned,
+                    )
+                    model.fit(
+                        X=y_train.drop(target, axis=1),
+                        y=pd.DataFrame(y_train[target]),
+                        time_budget=time_budget,
+                    )
                 logger.debug("Selected model: {}".format(model.selected_model_))
                 logger.debug(
                     "Selected model params: {}".format(model.selected_model_params_)
@@ -142,12 +147,8 @@ def _build_model(self) -> pd.DataFrame:
                 )

                 # Collect Outputs
-                selected_models[target] = {
-                    "series_id": target,
-                    "selected_model": model.selected_model_,
-                    "model_params": model.selected_model_params_,
-                }
-                models[target] = model
+                if self.loaded_models is None:
+                    models[target] = model
                 summary_frame = summary_frame.rename_axis("ds").reset_index()
                 summary_frame = summary_frame.rename(
                     columns={
@@ -162,6 +163,24 @@ def _build_model(self) -> pd.DataFrame:
                 summary_frame["yhat_lower"] = np.NAN
                 outputs[target] = summary_frame
                 # outputs_legacy[target] = summary_frame
+
+                self.model_parameters[target] = {
+                    "framework": SupportedModels.AutoMLX,
+                    "score_metric": model.score_metric,
+                    "random_state": model.random_state,
+                    "model_list": model.model_list,
+                    "n_algos_tuned": model.n_algos_tuned,
+                    "adaptive_sampling": model.adaptive_sampling,
+                    "min_features": model.min_features,
+                    "optimization": model.optimization,
+                    "preprocessing": model.preprocessing,
+                    "search_space": model.search_space,
+                    "time_series_period": model.time_series_period,
+                    "min_class_instances": model.min_class_instances,
+                    "max_tuning_trials": model.max_tuning_trials,
+                    "selected_model": model.selected_model_,
+                    "selected_model_params": model.selected_model_params_,
+                }
             except Exception as e:
                 self.errors_dict[target] = {"model_name": self.spec.model, "error": str(e)}

@@ -191,7 +210,8 @@ def _build_model(self) -> pd.DataFrame:
             # output_col = output_col.reset_index(drop=True)
             # outputs_merged = pd.concat([outputs_merged, output_col], axis=1)

-        self.models = models
+        self.models = models if self.loaded_models is None else self.loaded_models
+
         return outputs_merged

     @runtime_dependency(
@@ -262,7 +282,7 @@ def _generate_report(self):
         global_explanation_df = pd.DataFrame(self.global_explanation)

         self.formatted_global_explanation = (
-            global_explanation_df / global_explanation_df.sum(axis=0) * 100
+            global_explanation_df / global_explanation_df.sum(axis=0) * 100
         )

         # Create a markdown section for the global explainability
@@ -285,7 +305,7 @@ def _generate_report(self):
             dp.DataTable(
                 local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
                 label=s_id,
-            )
+            )
             for s_id, local_ex_df in self.local_explanation.items()
         ]
         local_explanation_section = (