77import pandas as pd
88import numpy as np
99import pmdarima as pm
10+ from joblib import Parallel , delayed
1011
1112from ads .opctl import logger
1213
@@ -29,32 +30,31 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
2930 self .formatted_global_explanation = None
3031 self .formatted_local_explanation = None
3132
32- def _build_model (self ) -> pd .DataFrame :
33- full_data_dict = self .datasets .full_data_dict
34-
35- # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
36- if self .spec .confidence_interval_width is None :
37- self .spec .confidence_interval_width = 1 - self .spec .model_kwargs .get (
38- "alpha" , 0.05
39- )
40- model_kwargs = self .spec .model_kwargs
41- model_kwargs ["alpha" ] = 1 - self .spec .confidence_interval_width
42- if "error_action" not in model_kwargs .keys ():
43- model_kwargs ["error_action" ] = "ignore"
33+ def _train_model (self , i , target , df ):
34+ """Trains the ARIMA model for a given target.
4435
45- models = []
46- self .datasets .datetime_col = self .spec .datetime_column .name
47- self .forecast_output = ForecastOutput (
48- confidence_interval_width = self .spec .confidence_interval_width
49- )
36+ Parameters
37+ ----------
38+ i: int
39+ The index of the target
40+ target: str
41+ The name of the target
42+ df: pd.DataFrame
43+ The dataframe containing the target data
44+ """
45+ try :
46+ # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
47+ if self .spec .confidence_interval_width is None :
48+ self .spec .confidence_interval_width = 1 - self .spec .model_kwargs .get (
49+ "alpha" , 0.05
50+ )
51+ model_kwargs = self .spec .model_kwargs
52+ model_kwargs ["alpha" ] = 1 - self .spec .confidence_interval_width
53+ if "error_action" not in model_kwargs .keys ():
54+ model_kwargs ["error_action" ] = "ignore"
5055
51- outputs = dict ()
52- outputs_legacy = []
53- fitted_values = dict ()
54- actual_values = dict ()
55- dt_columns = dict ()
56+ # models = []
5657
57- for i , (target , df ) in enumerate (full_data_dict .items ()):
5858 # format the dataframe for this target. Dropping NA on target[df] will remove all future data
5959 le , df_encoded = utils ._label_encode_dataframe (
6060 df , no_encode = {self .spec .datetime_column .name , target }
@@ -72,9 +72,7 @@ def _build_model(self) -> pd.DataFrame:
7272 target ,
7373 self .spec .datetime_column .name ,
7474 }
75- logger .debug (
76- f"Additional Regressors Detected { list (additional_regressors )} "
77- )
75+ logger .debug (f"Additional Regressors Detected { list (additional_regressors )} " )
7876
7977 # Split data into X and y for arima tune method
8078 y = data_i [target ]
@@ -85,19 +83,17 @@ def _build_model(self) -> pd.DataFrame:
8583 # Build and fit model
8684 model = pm .auto_arima (y = y , X = X_in , ** self .spec .model_kwargs )
8785
88- fitted_values [target ] = model .predict_in_sample (X = X_in )
89- actual_values [target ] = y
90- actual_values [target ].index = pd .to_datetime (y .index )
86+ self . fitted_values [target ] = model .predict_in_sample (X = X_in )
87+ self . actual_values [target ] = y
88+ self . actual_values [target ].index = pd .to_datetime (y .index )
9189
9290 # Build future dataframe
9391 start_date = y .index .values [- 1 ]
9492 n_periods = self .spec .horizon
9593 if len (additional_regressors ):
9694 X = df_clean [df_clean [target ].isnull ()].drop (target , axis = 1 )
9795 else :
98- X = pd .date_range (
99- start = start_date , periods = n_periods , freq = self .spec .freq
100- )
96+ X = pd .date_range (start = start_date , periods = n_periods , freq = self .spec .freq )
10197
10298 # Predict and format forecast
10399 yhat , conf_int = model .predict (
@@ -108,7 +104,7 @@ def _build_model(self) -> pd.DataFrame:
108104 )
109105 yhat_clean = pd .DataFrame (yhat , index = yhat .index , columns = ["yhat" ])
110106
111- dt_columns [target ] = df_encoded [self .spec .datetime_column .name ]
107+ self . dt_columns [target ] = df_encoded [self .spec .datetime_column .name ]
112108 conf_int_clean = pd .DataFrame (
113109 conf_int , index = yhat .index , columns = ["yhat_lower" , "yhat_upper" ]
114110 )
@@ -117,15 +113,42 @@ def _build_model(self) -> pd.DataFrame:
117113 logger .debug (forecast [["yhat" , "yhat_lower" , "yhat_upper" ]].tail ())
118114
119115 # Collect all outputs
120- models .append (model )
121- outputs_legacy .append (
116+ # models.append(model)
117+ self . outputs_legacy .append (
122118 forecast .reset_index ().rename (columns = {"index" : "ds" })
123119 )
124- outputs [target ] = forecast
120+ self .outputs [target ] = forecast
121+
122+ self .models_dict [target ] = model
125123
126- self .models = models
124+ logger .debug ("===========Done===========" )
125+ except Exception as e :
126+ self .errors_dict [target ] = {"model_name" : self .spec .model , "error" : str (e )}
127+
128+ def _build_model (self ) -> pd .DataFrame :
129+ full_data_dict = self .datasets .full_data_dict
130+
131+ self .datasets .datetime_col = self .spec .datetime_column .name
132+ self .forecast_output = ForecastOutput (
133+ confidence_interval_width = self .spec .confidence_interval_width
134+ )
135+
136+ self .outputs = dict ()
137+ self .outputs_legacy = []
138+ self .fitted_values = dict ()
139+ self .actual_values = dict ()
140+ self .dt_columns = dict ()
141+ self .models_dict = dict ()
142+ self .errors_dict = dict ()
143+
144+ Parallel (n_jobs = - 1 , require = "sharedmem" )(
145+ delayed (ArimaOperatorModel ._train_model )(self , i , target , df )
146+ for self , (i , (target , df )) in zip (
147+ [self ] * len (full_data_dict ), enumerate (full_data_dict .items ())
148+ )
149+ )
127150
128- logger . debug ( "===========Done===========" )
151+ self . models = [ self . models_dict [ target ] for target in self . target_columns ]
129152
130153 # Merge the outputs from each model into 1 df with all outputs by target and category
131154 col = self .original_target_column
@@ -134,15 +157,15 @@ def _build_model(self) -> pd.DataFrame:
134157 yhat_lower_name = ForecastOutputColumns .LOWER_BOUND
135158 for cat in self .categories :
136159 output_i = pd .DataFrame ()
137- output_i ["Date" ] = dt_columns [f"{ col } _{ cat } " ]
160+ output_i ["Date" ] = self . dt_columns [f"{ col } _{ cat } " ]
138161 output_i ["Series" ] = cat
139162 output_i = output_i .set_index ("Date" )
140163
141- output_i ["input_value" ] = actual_values [f"{ col } _{ cat } " ]
142- output_i ["fitted_value" ] = fitted_values [f"{ col } _{ cat } " ]
143- output_i ["forecast_value" ] = outputs [f"{ col } _{ cat } " ]["yhat" ]
144- output_i [yhat_upper_name ] = outputs [f"{ col } _{ cat } " ]["yhat_upper" ]
145- output_i [yhat_lower_name ] = outputs [f"{ col } _{ cat } " ]["yhat_lower" ]
164+ output_i ["input_value" ] = self . actual_values [f"{ col } _{ cat } " ]
165+ output_i ["fitted_value" ] = self . fitted_values [f"{ col } _{ cat } " ]
166+ output_i ["forecast_value" ] = self . outputs [f"{ col } _{ cat } " ]["yhat" ]
167+ output_i [yhat_upper_name ] = self . outputs [f"{ col } _{ cat } " ]["yhat_upper" ]
168+ output_i [yhat_lower_name ] = self . outputs [f"{ col } _{ cat } " ]["yhat_lower" ]
146169
147170 output_i = output_i .reset_index (drop = False )
148171 output_col = pd .concat ([output_col , output_i ])
@@ -252,7 +275,7 @@ def _custom_predict_arima(self, data):
252275
253276 """
254277 date_col = self .spec .datetime_column .name
255- data [date_col ] = pd .to_datetime (data [date_col ], unit = 's' )
278+ data [date_col ] = pd .to_datetime (data [date_col ], unit = "s" )
256279 data = data .set_index (date_col )
257280 # Get the index of the current series id
258281 series_index = self .target_columns .index (self .series_id )
0 commit comments