@@ -19,6 +19,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
1919 self .local_explanation = {}
2020 self .formatted_global_explanation = None
2121 self .formatted_local_explanation = None
22+ self .date_col = config .spec .datetime_column .name
2223
2324 def set_kwargs (self ):
2425 """
@@ -73,8 +74,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
7374 alpha = model_kwargs ["lower_quantile" ],
7475 ),
7576 },
76- freq = pd .infer_freq (data_train ["Date" ].drop_duplicates ())
77- or pd .infer_freq (data_train ["Date" ].drop_duplicates ()[- 5 :]),
77+ freq = pd .infer_freq (data_train [self . date_col ].drop_duplicates ())
78+ or pd .infer_freq (data_train [self . date_col ].drop_duplicates ()[- 5 :]),
7879 target_transforms = [Differences ([12 ])],
7980 lags = model_kwargs .get (
8081 "lags" ,
@@ -104,7 +105,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
104105 data_train [self .model_columns ],
105106 static_features = model_kwargs .get ("static_features" , []),
106107 id_col = ForecastOutputColumns .SERIES ,
107- time_col = self .spec . datetime_column . name ,
108+ time_col = self .date_col ,
108109 target_col = self .spec .target_column ,
109110 fitted = True ,
110111 max_horizon = None if num_models is False else self .spec .horizon ,
@@ -168,7 +169,7 @@ def _build_model(self) -> pd.DataFrame:
168169 confidence_interval_width = self .spec .confidence_interval_width ,
169170 horizon = self .spec .horizon ,
170171 target_column = self .original_target_column ,
171- dt_column = self .spec . datetime_column . name ,
172+ dt_column = self .date_col ,
172173 )
173174 self ._train_model (data_train , data_test , model_kwargs )
174175 return self .forecast_output .get_forecast_long ()
0 commit comments