@@ -61,6 +61,18 @@ def _train_model(self, data_train, data_test, model_kwargs):
6161 "verbosity" : - 1 ,
6262 "num_leaves" : 512 ,
6363 }
64+ additional_data_params = {}
65+ if len (self .datasets .get_additional_data_column_names ()) > 0 :
66+ additional_data_params = {
67+ "target_transforms" : [Differences ([12 ])],
68+ "lags" : model_kwargs .get ("lags" , [1 , 6 , 12 ]),
69+ "lag_transforms" : (
70+ {
71+ 1 : [ExpandingMean ()],
72+ 12 : [RollingMean (window_size = 24 )],
73+ }
74+ ),
75+ }
6476
6577 fcst = MLForecast (
6678 models = {
@@ -80,24 +92,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
8092 },
8193 freq = pd .infer_freq (data_train [self .date_col ].drop_duplicates ())
8294 or pd .infer_freq (data_train [self .date_col ].drop_duplicates ()[- 5 :]),
83- target_transforms = [Differences ([12 ])],
84- lags = model_kwargs .get (
85- "lags" ,
86- (
87- [1 , 6 , 12 ]
88- if len (self .datasets .get_additional_data_column_names ()) > 0
89- else []
90- ),
91- ),
92- lag_transforms = (
93- {
94- 1 : [ExpandingMean ()],
95- 12 : [RollingMean (window_size = 24 )],
96- }
97- if len (self .datasets .get_additional_data_column_names ()) > 0
98- else {}
99- ),
100- # date_features=[hour_index],
95+ ** additional_data_params ,
10196 )
10297
10398 num_models = model_kwargs .get ("recursive_models" , False )
@@ -164,6 +159,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
164159 "error" : str (e ),
165160 }
166161 logger .debug (f"Encountered Error: { e } . Skipping." )
162+ raise e
167163
168164 def _build_model (self ) -> pd .DataFrame :
169165 data_train = self .datasets .get_all_data_long (include_horizon = False )
0 commit comments