55
66from typing import Dict , List
77
8+ import numpy as np
89import pandas as pd
910
1011from ads .opctl import logger
1819 get_frequency_of_datetime ,
1920)
2021
21- from ..const import ForecastOutputColumns , SupportedModels , TROUBLESHOOTING_GUIDE
22- from ..operator_config import ForecastOperatorConfig
22+ from ..const import TROUBLESHOOTING_GUIDE , ForecastOutputColumns , SupportedModels
23+ from ..operator_config import ForecastOperatorConfig , PostprocessingSteps
2324
2425
2526class HistoricalData (AbstractData ):
def __init__(self, spec, historical_data=None, subset=None):
    """Initialize the historical-data wrapper.

    Args:
        spec: Operator spec object, forwarded unchanged to the parent
            initializer.
        historical_data: Optional pre-loaded data, forwarded to the parent
            initializer as ``data``. Defaults to ``None``.
        subset: Optional subset selector, forwarded to the parent
            initializer and also stored on the instance.
    """
    # The diff view showed both the removed one-line call and the added
    # multi-line call; they are the same call, so only one is kept.
    super().__init__(
        spec=spec, name="historical_data", data=historical_data, subset=subset
    )
    self.subset = subset
2932
3033 def _ingest_data (self , spec ):
@@ -49,15 +52,19 @@ def _verify_dt_col(self, spec):
4952 f"{ SupportedModels .AutoMLX } requires data with a frequency of at least one hour. Please try using a different model,"
5053 " or select the 'auto' option."
5154 )
52- raise InvalidParameterError (f"{ message } "
53- f"\n Please refer to the troubleshooting guide at { TROUBLESHOOTING_GUIDE } for resolution steps." )
55+ raise InvalidParameterError (
56+ f"{ message } "
57+ f"\n Please refer to the troubleshooting guide at { TROUBLESHOOTING_GUIDE } for resolution steps."
58+ )
5459
5560
5661class AdditionalData (AbstractData ):
5762 def __init__ (self , spec , historical_data , additional_data = None , subset = None ):
5863 self .subset = subset
5964 if additional_data is not None :
60- super ().__init__ (spec = spec , name = "additional_data" , data = additional_data , subset = subset )
65+ super ().__init__ (
66+ spec = spec , name = "additional_data" , data = additional_data , subset = subset
67+ )
6168 self .additional_regressors = list (self .data .columns )
6269 elif spec .additional_data is not None :
6370 super ().__init__ (spec = spec , name = "additional_data" , subset = subset )
@@ -70,7 +77,7 @@ def __init__(self, spec, historical_data, additional_data=None, subset=None):
7077 )
7178 elif historical_data .get_max_time () != add_dates [- (spec .horizon + 1 )]:
7279 raise DataMismatchError (
73- f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on { historical_data .get_max_time ()} . The additonal data horizon starts after { add_dates [- (spec .horizon + 1 )]} . These should be the same date."
80+ f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on { historical_data .get_max_time ()} . The additonal data horizon starts after { add_dates [- (spec .horizon + 1 )]} . These should be the same date."
7481 f"\n Please refer to the troubleshooting guide at { TROUBLESHOOTING_GUIDE } for resolution steps."
7582 )
7683 else :
@@ -150,7 +157,9 @@ def __init__(
150157 self ._datetime_column_name = config .spec .datetime_column .name
151158 self ._target_col = config .spec .target_column
152159 if historical_data is not None :
153- self .historical_data = HistoricalData (config .spec , historical_data , subset = subset )
160+ self .historical_data = HistoricalData (
161+ config .spec , historical_data , subset = subset
162+ )
154163 self .additional_data = AdditionalData (
155164 config .spec , self .historical_data , additional_data , subset = subset
156165 )
@@ -276,6 +285,7 @@ def __init__(
276285 horizon : int ,
277286 target_column : str ,
278287 dt_column : str ,
288+ postprocessing : PostprocessingSteps ,
279289 ):
280290 """Forecast Output contains all the details required to generate the forecast.csv output file.
281291
@@ -285,12 +295,14 @@ def __init__(
285295 horizon: int length of horizon
286296 target_column: str the name of the original target column
287297 dt_column: the name of the original datetime column
298+ postprocessing: postprocessing steps to be executed
288299 """
289300 self .series_id_map = {}
290301 self ._set_ci_column_names (confidence_interval_width )
291302 self .horizon = horizon
292303 self .target_column_name = target_column
293304 self .dt_column_name = dt_column
305+ self .postprocessing = postprocessing
294306
295307 def add_series_id (
296308 self ,
@@ -337,6 +349,12 @@ def populate_series_output(
337349 --------
338350 None
339351 """
352+ min_threshold , max_threshold = (
353+ self .postprocessing .set_min_forecast ,
354+ self .postprocessing .set_max_forecast ,
355+ )
356+ if min_threshold is not None or max_threshold is not None :
357+ np .clip (forecast_val , min_threshold , max_threshold , out = forecast_val )
340358 try :
341359 output_i = self .series_id_map [series_id ]
342360 except KeyError as e :
@@ -422,9 +440,9 @@ def _set_ci_column_names(self, confidence_interval_width):
422440
423441 def _check_forecast_format (self , forecast ):
424442 assert isinstance (forecast , pd .DataFrame )
425- assert (
426- len ( forecast . columns ) == 7
427- ), f"Expected just 7 columns, but got: { forecast . columns } "
443+ assert len ( forecast . columns ) == 7 , (
444+ f"Expected just 7 columns, but got: { forecast . columns } "
445+ )
428446 assert ForecastOutputColumns .DATE in forecast .columns
429447 assert ForecastOutputColumns .SERIES in forecast .columns
430448 assert ForecastOutputColumns .INPUT_VALUE in forecast .columns
@@ -506,16 +524,30 @@ def set_errors_dict(self, errors_dict: Dict):
def get_errors_dict(self):
    """Return the errors dictionary stored on this object.

    Returns:
        The value of the ``errors_dict`` attribute, or ``None`` when that
        attribute has not been set on this instance.
    """
    # getattr with a default avoids an AttributeError when the attribute
    # was never assigned.
    return getattr(self, "errors_dict", None)
508526
509- def merge (self , other : ' ForecastResults' ):
527+ def merge (self , other : " ForecastResults" ):
510528 """Merge another ForecastResults object into this one."""
511529 # Merge DataFrames if they exist, else just set
512530 for attr in [
513- 'forecast' , 'metrics' , 'test_metrics' , 'local_explanations' , 'global_explanations' , 'model_parameters' , 'models' , 'errors_dict' ]:
531+ "forecast" ,
532+ "metrics" ,
533+ "test_metrics" ,
534+ "local_explanations" ,
535+ "global_explanations" ,
536+ "model_parameters" ,
537+ "models" ,
538+ "errors_dict" ,
539+ ]:
514540 val_self = getattr (self , attr , None )
515541 val_other = getattr (other , attr , None )
516542 if val_self is not None and val_other is not None :
517- if isinstance (val_self , pd .DataFrame ) and isinstance (val_other , pd .DataFrame ):
518- setattr (self , attr , pd .concat ([val_self , val_other ], ignore_index = True , axis = 0 ))
543+ if isinstance (val_self , pd .DataFrame ) and isinstance (
544+ val_other , pd .DataFrame
545+ ):
546+ setattr (
547+ self ,
548+ attr ,
549+ pd .concat ([val_self , val_other ], ignore_index = True , axis = 0 ),
550+ )
519551 elif isinstance (val_self , dict ) and isinstance (val_other , dict ):
520552 val_self .update (val_other )
521553 setattr (self , attr , val_self )
0 commit comments