@@ -56,8 +56,8 @@ def set_kwargs(self):
5656 )
5757 return model_kwargs_cleaned , time_budget
5858
59- def preprocess (self , data ): # TODO: re-use self.le for explanations
60- _ , df_encoded = _label_encode_dataframe (
59+ def preprocess (self , data , series_id ): # TODO: re-use self.le for explanations
60+ self . le [ series_id ] , df_encoded = _label_encode_dataframe (
6161 data ,
6262 no_encode = {self .spec .datetime_column .name , self .original_target_column },
6363 )
@@ -124,7 +124,7 @@ def _build_model(self) -> pd.DataFrame:
124124 self .forecast_output .init_series_output (
125125 series_id = s_id , data_at_series = df
126126 )
127- data = self .preprocess (df )
127+ data = self .preprocess (df , s_id )
128128 data_i = self .drop_horizon (data )
129129 X_pred = self .get_horizon (data ).drop (target , axis = 1 )
130130
@@ -156,7 +156,9 @@ def _build_model(self) -> pd.DataFrame:
156156 target
157157 ].values
158158
159- self .models [s_id ] = model
159+ self .models [s_id ] = {}
160+ self .models [s_id ]["model" ] = model
161+ self .models [s_id ]["le" ] = self .le [s_id ]
160162
161163 # In case of Naive model, model.forecast function call does not return confidence intervals.
162164 if f"{ target } _ci_upper" not in summary_frame :
@@ -217,7 +219,8 @@ def _generate_report(self):
217219 other_sections = []
218220
219221 if len (self .models ) > 0 :
220- for s_id , m in models .items ():
222+ for s_id , artifacts in models .items ():
223+ m = artifacts ["model" ]
221224 selected_models [s_id ] = {
222225 "series_id" : s_id ,
223226 "selected_model" : m .selected_model_ ,
@@ -323,7 +326,7 @@ def _generate_report(self):
323326 )
324327
325328 def get_explain_predict_fn (self , series_id ):
326- selected_model = self .models [series_id ]
329+ selected_model = self .models [series_id ][ "model" ]
327330
328331 # If training date, use method below. If future date, use forecast!
329332 def _custom_predict_fn (
@@ -341,12 +344,12 @@ def _custom_predict_fn(
341344 data [dt_column_name ] = seconds_to_datetime (
342345 data [dt_column_name ], dt_format = self .spec .datetime_column .format
343346 )
344- data = self .preprocess (data )
347+ data = self .preprocess (data , series_id )
345348 horizon_data = horizon_data .drop (target_col , axis = 1 )
346349 horizon_data [dt_column_name ] = seconds_to_datetime (
347350 horizon_data [dt_column_name ], dt_format = self .spec .datetime_column .format
348351 )
349- horizon_data = self .preprocess (horizon_data )
352+ horizon_data = self .preprocess (horizon_data , series_id )
350353
351354 rows = []
352355 for i in range (data .shape [0 ]):
@@ -424,10 +427,8 @@ def explain_model(self):
424427 if self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX :
425428 # Use the MLExplainer class from AutoMLx to generate explanations
426429 explainer = automlx .MLExplainer (
427- self .models [s_id ],
428- self .datasets .additional_data .get_data_for_series (
429- series_id = s_id
430- )
430+ self .models [s_id ]["model" ],
431+ self .datasets .additional_data .get_data_for_series (series_id = s_id )
431432 .drop (self .spec .datetime_column .name , axis = 1 )
432433 .head (- self .spec .horizon )
433434 if self .spec .additional_data
0 commit comments