@@ -149,36 +149,42 @@ def _train_model(self, i, s_id, df, model_kwargs):
149149 logger .debug (f"-----------------Model { i } ----------------------" )
150150 logger .debug (forecast .tail ())
151151
152- # TODO; could also extract trend and seasonality?
153- cols_to_read = set (
154- forecast .columns [forecast .columns .str .startswith ("future_regressor" )]
155- + ["ds" , "trend" ]
156- )
157- cols_to_read = cols_to_read - {
158- "future_regressors_additive" ,
159- "future_regressors_multiplicative" ,
160- }
161- combine_terms = cols_to_read - set (self .accepted_regressors [s_id ])
162- temp_df = (
163- forecast [list (cols_to_read )]
164- .rename ({"ds" : "Date" }, axis = 1 )
165- .set_index ("Date" )
166- )
167- temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
168- self .explanations_info [s_id ] = temp_df .drop (combine_terms , axis = 1 )
169-
170152 self .outputs [s_id ] = forecast
153+ upper_bound_col_name = f"yhat1 { model_kwargs ['quantiles' ][1 ]* 100 } %"
154+ lower_bound_col_name = f"yhat1 { model_kwargs ['quantiles' ][0 ]* 100 } %"
171155 self .forecast_output .populate_series_output (
172156 series_id = s_id ,
173157 fit_val = self .drop_horizon (forecast ["yhat1" ]).values ,
174158 forecast_val = self .get_horizon (forecast ["yhat1" ]).values ,
175- upper_bound = self .get_horizon (
176- forecast [f"yhat1 { model_kwargs ['quantiles' ][1 ]* 100 } %" ]
177- ).values ,
178- lower_bound = self .get_horizon (
179- forecast [f"yhat1 { model_kwargs ['quantiles' ][0 ]* 100 } %" ]
180- ).values ,
159+ upper_bound = self .get_horizon (forecast [upper_bound_col_name ]).values ,
160+ lower_bound = self .get_horizon (forecast [lower_bound_col_name ]).values ,
181161 )
162+ core_columns = set (forecast .columns ) - set (
163+ [
164+ "y" ,
165+ "yhat1" ,
166+ upper_bound_col_name ,
167+ lower_bound_col_name ,
168+ "future_regressors_additive" ,
169+ "future_regressors_multiplicative" ,
170+ ]
171+ )
172+ exog_variables = set (
173+ filter (lambda x : x .startswith ("future_regressor_" ), list (core_columns ))
174+ )
175+ combine_terms = list (core_columns - exog_variables - set (["ds" ]))
176+ temp_df = (
177+ forecast [list (core_columns )]
178+ .rename ({"ds" : "Date" }, axis = 1 )
179+ .set_index ("Date" )
180+ )
181+ if combine_terms :
182+ temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
183+ temp_df = temp_df .drop (combine_terms , axis = 1 )
184+ else :
185+ temp_df [self .spec .target_column ] = 0
186+ # Todo: check for columns that were dropped, and set them to 0
187+ self .explanations_info [s_id ] = temp_df
182188
183189 self .trainers [s_id ] = model .trainer
184190 self .models [s_id ] = {}
0 commit comments