@@ -56,8 +56,8 @@ def set_kwargs(self):
5656 )
5757 return model_kwargs_cleaned , time_budget
5858
59- def preprocess (self , data ): # TODO: re-use self.le for explanations
60- _ , df_encoded = _label_encode_dataframe (
59+ def preprocess (self , data , series_id ): # TODO: re-use self.le for explanations
60+ self . le [ series_id ] , df_encoded = _label_encode_dataframe (
6161 data ,
6262 no_encode = {self .spec .datetime_column .name , self .original_target_column },
6363 )
@@ -125,7 +125,7 @@ def _build_model(self) -> pd.DataFrame:
125125 self .forecast_output .init_series_output (
126126 series_id = s_id , data_at_series = df
127127 )
128- data = self .preprocess (df )
128+ data = self .preprocess (df , s_id )
129129 data_i = self .drop_horizon (data )
130130 X_pred = self .get_horizon (data ).drop (target , axis = 1 )
131131
@@ -157,7 +157,9 @@ def _build_model(self) -> pd.DataFrame:
157157 target
158158 ].values
159159
160- self .models [s_id ] = model
160+ self .models [s_id ] = {}
161+ self .models [s_id ]["model" ] = model
162+ self .models [s_id ]["le" ] = self .le [s_id ]
161163
162164 # In case of Naive model, model.forecast function call does not return confidence intervals.
163165 if f"{ target } _ci_upper" not in summary_frame :
@@ -218,7 +220,8 @@ def _generate_report(self):
218220 other_sections = []
219221
220222 if len (self .models ) > 0 :
221- for s_id , m in models .items ():
223+ for s_id , artifacts in models .items ():
224+ m = artifacts ["model" ]
222225 selected_models [s_id ] = {
223226 "series_id" : s_id ,
224227 "selected_model" : m .selected_model_ ,
@@ -320,7 +323,7 @@ def _generate_report(self):
320323 )
321324
322325 def get_explain_predict_fn (self , series_id ):
323- selected_model = self .models [series_id ]
326+ selected_model = self .models [series_id ][ "model" ]
324327
325328 # If training date, use method below. If future date, use forecast!
326329 def _custom_predict_fn (
@@ -338,12 +341,12 @@ def _custom_predict_fn(
338341 data [dt_column_name ] = seconds_to_datetime (
339342 data [dt_column_name ], dt_format = self .spec .datetime_column .format
340343 )
341- data = self .preprocess (data )
344+ data = self .preprocess (data , series_id )
342345 horizon_data = horizon_data .drop (target_col , axis = 1 )
343346 horizon_data [dt_column_name ] = seconds_to_datetime (
344347 horizon_data [dt_column_name ], dt_format = self .spec .datetime_column .format
345348 )
346- horizon_data = self .preprocess (horizon_data )
349+ horizon_data = self .preprocess (horizon_data , series_id )
347350
348351 rows = []
349352 for i in range (data .shape [0 ]):
@@ -421,7 +424,7 @@ def explain_model(self):
421424 if self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX :
422425 # Use the MLExplainer class from AutoMLx to generate explanations
423426 explainer = automlx .MLExplainer (
424- self .models [s_id ],
427+ self .models [s_id ][ "model" ] ,
425428 self .datasets .additional_data .get_data_for_series (series_id = s_id )
426429 .drop (self .spec .datetime_column .name , axis = 1 )
427430 .head (- self .spec .horizon )
0 commit comments