@@ -56,8 +56,8 @@ def set_kwargs(self):
5656 )
5757 return model_kwargs_cleaned , time_budget
5858
59- def preprocess (self , data ): # TODO: re-use self.le for explanations
60- _ , df_encoded = _label_encode_dataframe (
59+ def preprocess (self , data , series_id ): # TODO: re-use self.le for explanations
60+ self . le [ series_id ] , df_encoded = _label_encode_dataframe (
6161 data ,
6262 no_encode = {self .spec .datetime_column .name , self .original_target_column },
6363 )
@@ -125,7 +125,7 @@ def _build_model(self) -> pd.DataFrame:
125125 self .forecast_output .init_series_output (
126126 series_id = s_id , data_at_series = df
127127 )
128- data = self .preprocess (df )
128+ data = self .preprocess (df , s_id )
129129 data_i = self .drop_horizon (data )
130130 X_pred = self .get_horizon (data ).drop (target , axis = 1 )
131131
@@ -157,7 +157,9 @@ def _build_model(self) -> pd.DataFrame:
157157 target
158158 ].values
159159
160- self .models [s_id ] = model
160+ self .models [s_id ] = {}
161+ self .models [s_id ]["model" ] = model
162+ self .models [s_id ]["le" ] = self .le [s_id ]
161163
162164 # In case of Naive model, model.forecast function call does not return confidence intervals.
163165 if f"{ target } _ci_upper" not in summary_frame :
@@ -218,7 +220,8 @@ def _generate_report(self):
218220 other_sections = []
219221
220222 if len (self .models ) > 0 :
221- for s_id , m in models .items ():
223+ for s_id , artifacts in models .items ():
224+ m = artifacts ["model" ]
222225 selected_models [s_id ] = {
223226 "series_id" : s_id ,
224227 "selected_model" : m .selected_model_ ,
@@ -326,7 +329,7 @@ def _generate_report(self):
326329 )
327330
328331 def get_explain_predict_fn (self , series_id ):
329- selected_model = self .models [series_id ]
332+ selected_model = self .models [series_id ][ "model" ]
330333
331334 # If training date, use method below. If future date, use forecast!
332335 def _custom_predict_fn (
@@ -344,12 +347,12 @@ def _custom_predict_fn(
344347 data [dt_column_name ] = seconds_to_datetime (
345348 data [dt_column_name ], dt_format = self .spec .datetime_column .format
346349 )
347- data = self .preprocess (data )
350+ data = self .preprocess (data , series_id )
348351 horizon_data = horizon_data .drop (target_col , axis = 1 )
349352 horizon_data [dt_column_name ] = seconds_to_datetime (
350353 horizon_data [dt_column_name ], dt_format = self .spec .datetime_column .format
351354 )
352- horizon_data = self .preprocess (horizon_data )
355+ horizon_data = self .preprocess (horizon_data , series_id )
353356
354357 rows = []
355358 for i in range (data .shape [0 ]):
0 commit comments