@@ -29,6 +29,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
2929 self .local_explanation = {}
3030 self .formatted_global_explanation = None
3131 self .formatted_local_explanation = None
32+ self .constant_cols = {}
3233
3334 def set_kwargs (self ):
3435 # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
@@ -64,6 +65,10 @@ def _train_model(self, i, s_id, df, model_kwargs):
6465 try :
6566 target = self .original_target_column
6667 self .forecast_output .init_series_output (series_id = s_id , data_at_series = df )
68+ # If trend is constant, remove constant columns
69+ if 'trend' not in model_kwargs or model_kwargs ['trend' ] == 'c' :
70+ self .constant_cols [s_id ] = df .columns [df .nunique () == 1 ]
71+ df = df .drop (columns = self .constant_cols [s_id ])
6772
6873 # format the dataframe for this target. Dropping NA on target[df] will remove all future data
6974 data = self .preprocess (df , s_id )
@@ -74,7 +79,7 @@ def _train_model(self, i, s_id, df, model_kwargs):
7479 X_in = data_i .drop (target , axis = 1 ) if len (data_i .columns ) > 1 else None
7580 X_pred = self .get_horizon (data ).drop (target , axis = 1 )
7681
77- if self .loaded_models is not None :
82+ if self .loaded_models is not None and s_id in self . loaded_models :
7883 model = self .loaded_models [s_id ]
7984 else :
8085 # Build and fit model
@@ -143,17 +148,18 @@ def _build_model(self) -> pd.DataFrame:
143148 def _generate_report (self ):
144149 """The method that needs to be implemented on the particular model level."""
145150 import datapane as dp
146-
147- sec5_text = dp .Text (f"## ARIMA Model Parameters" )
148- blocks = [
149- dp .HTML (
150- m .summary ().as_html (),
151- label = s_id ,
152- )
153- for i , (s_id , m ) in enumerate (self .models .items ())
154- ]
155- sec5 = dp .Select (blocks = blocks ) if len (blocks ) > 1 else blocks [0 ]
156- all_sections = [sec5_text , sec5 ]
151+ all_sections = []
152+ if len (self .models ) > 0 :
153+ sec5_text = dp .Text (f"## ARIMA Model Parameters" )
154+ blocks = [
155+ dp .HTML (
156+ m .summary ().as_html (),
157+ label = s_id ,
158+ )
159+ for i , (s_id , m ) in enumerate (self .models .items ())
160+ ]
161+ sec5 = dp .Select (blocks = blocks ) if len (blocks ) > 1 else blocks [0 ]
162+ all_sections = [sec5_text , sec5 ]
157163
158164 if self .spec .generate_explanations :
159165 try :
@@ -239,6 +245,9 @@ def _custom_predict(
239245 """
240246 data: ForecastDatasets.get_data_at_series(s_id)
241247 """
248+ if series_id in self .constant_cols :
249+ data = data .drop (columns = self .constant_cols [series_id ])
250+
242251 data = data .drop ([target_col ], axis = 1 )
243252 data [dt_column_name ] = seconds_to_datetime (
244253 data [dt_column_name ], dt_format = self .spec .datetime_column .format
0 commit comments