 import pandas as pd
 import numpy as np
 import pmdarima as pm
+from joblib import Parallel, delayed

 from ads.opctl import logger

 from ..operator_config import ForecastOperatorConfig
 import traceback
 from .forecast_datasets import ForecastDatasets, ForecastOutput
-from ..const import ForecastOutputColumns, SupportedModels
+from ..const import ForecastOutputColumns


 class ArimaOperatorModel(ForecastOperatorBaseModel):
@@ -29,32 +30,31 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
         self.formatted_global_explanation = None
         self.formatted_local_explanation = None

-    def _build_model(self) -> pd.DataFrame:
-        full_data_dict = self.datasets.full_data_dict
-
-        # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
-        if self.spec.confidence_interval_width is None:
-            self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
-                "alpha", 0.05
-            )
-        model_kwargs = self.spec.model_kwargs
-        model_kwargs["alpha"] = 1 - self.spec.confidence_interval_width
-        if "error_action" not in model_kwargs.keys():
-            model_kwargs["error_action"] = "ignore"
+    def _train_model(self, i, target, df):
+        """Trains the ARIMA model for a given target.

-        models = []
-        self.datasets.datetime_col = self.spec.datetime_column.name
-        self.forecast_output = ForecastOutput(
-            confidence_interval_width=self.spec.confidence_interval_width
-        )
+        Parameters
+        ----------
+        i: int
+            The index of the target
+        target: str
+            The name of the target
+        df: pd.DataFrame
+            The dataframe containing the target data
+        """
+        try:
+            # Extract the Confidence Interval Width and convert to arima's equivalent - alpha
+            if self.spec.confidence_interval_width is None:
+                self.spec.confidence_interval_width = 1 - self.spec.model_kwargs.get(
+                    "alpha", 0.05
+                )
+            model_kwargs = self.spec.model_kwargs
+            model_kwargs["alpha"] = 1 - self.spec.confidence_interval_width
+            if "error_action" not in model_kwargs.keys():
+                model_kwargs["error_action"] = "ignore"
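+            # Note: pmdarima treats alpha as the significance level, so a requested 0.95
+            # interval width maps to alpha=0.05; error_action="ignore" tells auto_arima to
+            # skip candidate orders that fail to fit instead of raising.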

-        outputs = dict()
-        outputs_legacy = []
-        fitted_values = dict()
-        actual_values = dict()
-        dt_columns = dict()
+            # models = []

-        for i, (target, df) in enumerate(full_data_dict.items()):
             # format the dataframe for this target. Dropping NA on target[df] will remove all future data
             le, df_encoded = utils._label_encode_dataframe(
                 df, no_encode={self.spec.datetime_column.name, target}
@@ -72,34 +72,28 @@ def _build_model(self) -> pd.DataFrame:
                 target,
                 self.spec.datetime_column.name,
             }
-            logger.debug(
-                f"Additional Regressors Detected {list(additional_regressors)}"
-            )
+            logger.debug(f"Additional Regressors Detected {list(additional_regressors)}")

             # Split data into X and y for arima tune method
             y = data_i[target]
             X_in = None
             if len(additional_regressors):
                 X_in = data_i.drop(target, axis=1)

-            model = self.loaded_models[i] if self.loaded_models is not None else None
-            if model is None:
-                # Build and fit model
-                model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)
+            # Build and fit model
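+            # auto_arima (pmdarima's default stepwise search) picks the best (p, d, q)(P, D, Q)
+            # orders for this series; additional regressors go in as the exogenous matrix X.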
+            model = pm.auto_arima(y=y, X=X_in, **self.spec.model_kwargs)

-            fitted_values[target] = model.predict_in_sample(X=X_in)
-            actual_values[target] = y
-            actual_values[target].index = pd.to_datetime(y.index)
+            self.fitted_values[target] = model.predict_in_sample(X=X_in)
+            self.actual_values[target] = y
+            self.actual_values[target].index = pd.to_datetime(y.index)
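+            # Stash in-sample predictions and actuals per target; they feed the merged output frame below.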

             # Build future dataframe
             start_date = y.index.values[-1]
             n_periods = self.spec.horizon
             if len(additional_regressors):
                 X = df_clean[df_clean[target].isnull()].drop(target, axis=1)
             else:
-                X = pd.date_range(
-                    start=start_date, periods=n_periods, freq=self.spec.freq
-                )
+                X = pd.date_range(start=start_date, periods=n_periods, freq=self.spec.freq)

             # Predict and format forecast
             yhat, conf_int = model.predict(
@@ -110,7 +104,7 @@ def _build_model(self) -> pd.DataFrame:
             )
             yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])

-            dt_columns[target] = df_encoded[self.spec.datetime_column.name]
+            self.dt_columns[target] = df_encoded[self.spec.datetime_column.name]
             conf_int_clean = pd.DataFrame(
                 conf_int, index=yhat.index, columns=["yhat_lower", "yhat_upper"]
             )
@@ -119,25 +113,42 @@ def _build_model(self) -> pd.DataFrame:
             logger.debug(forecast[["yhat", "yhat_lower", "yhat_upper"]].tail())

             # Collect all outputs
-            if self.loaded_models is None:
-                models.append(model)
-            outputs_legacy.append(
+            # models.append(model)
+            self.outputs_legacy.append(
                 forecast.reset_index().rename(columns={"index": "ds"})
             )
-            outputs[target] = forecast
-
-            params = vars(model).copy()
-            for param in ['arima_res_', 'endog_index_']:
-                if param in params:
-                    params.pop(param)
-            self.model_parameters[target] = {
-                "framework": SupportedModels.Arima,
-                **params,
-            }
+            self.outputs[target] = forecast
+
+            self.models_dict[target] = model

-        self.models = self.loaded_models if self.loaded_models is not None else models
+            logger.debug("===========Done===========")
+        except Exception as e:
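+            # Record per-target failures so one bad series does not abort the whole run.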
+            self.errors_dict[target] = {"model_name": self.spec.model, "error": str(e)}
+
+    def _build_model(self) -> pd.DataFrame:
+        full_data_dict = self.datasets.full_data_dict
+
+        self.datasets.datetime_col = self.spec.datetime_column.name
+        self.forecast_output = ForecastOutput(
+            confidence_interval_width=self.spec.confidence_interval_width
+        )
+
+        self.outputs = dict()
+        self.outputs_legacy = []
+        self.fitted_values = dict()
+        self.actual_values = dict()
+        self.dt_columns = dict()
+        self.models_dict = dict()
+        self.errors_dict = dict()
+
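+        # require="sharedmem" forces joblib's threading backend, so every worker
+        # writes into the same result dicts on self rather than pickled copies.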
+        Parallel(n_jobs=-1, require="sharedmem")(
+            delayed(ArimaOperatorModel._train_model)(self, i, target, df)
+            for self, (i, (target, df)) in zip(
+                [self] * len(full_data_dict), enumerate(full_data_dict.items())
+            )
+        )

-        logger.debug("===========Done===========")
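+        # models_dict is keyed by target, so the list below follows target_columns order
+        # regardless of which parallel worker finished first.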
+        self.models = [self.models_dict[target] for target in self.target_columns]

         # Merge the outputs from each model into 1 df with all outputs by target and category
         col = self.original_target_column
@@ -146,15 +157,15 @@ def _build_model(self) -> pd.DataFrame:
         yhat_lower_name = ForecastOutputColumns.LOWER_BOUND
         for cat in self.categories:
             output_i = pd.DataFrame()
-            output_i["Date"] = dt_columns[f"{col}_{cat}"]
+            output_i["Date"] = self.dt_columns[f"{col}_{cat}"]
             output_i["Series"] = cat
             output_i = output_i.set_index("Date")

-            output_i["input_value"] = actual_values[f"{col}_{cat}"]
-            output_i["fitted_value"] = fitted_values[f"{col}_{cat}"]
-            output_i["forecast_value"] = outputs[f"{col}_{cat}"]["yhat"]
-            output_i[yhat_upper_name] = outputs[f"{col}_{cat}"]["yhat_upper"]
-            output_i[yhat_lower_name] = outputs[f"{col}_{cat}"]["yhat_lower"]
+            output_i["input_value"] = self.actual_values[f"{col}_{cat}"]
+            output_i["fitted_value"] = self.fitted_values[f"{col}_{cat}"]
+            output_i["forecast_value"] = self.outputs[f"{col}_{cat}"]["yhat"]
+            output_i[yhat_upper_name] = self.outputs[f"{col}_{cat}"]["yhat_upper"]
+            output_i[yhat_lower_name] = self.outputs[f"{col}_{cat}"]["yhat_lower"]

             output_i = output_i.reset_index(drop=False)
             output_col = pd.concat([output_col, output_i])
@@ -196,7 +207,7 @@ def _generate_report(self):
         global_explanation_df = pd.DataFrame(self.global_explanation)

         self.formatted_global_explanation = (
-            global_explanation_df / global_explanation_df.sum(axis=0) * 100
+            global_explanation_df / global_explanation_df.sum(axis=0) * 100
         )

         # Create a markdown section for the global explainability
@@ -264,7 +275,7 @@ def _custom_predict_arima(self, data):

         """
         date_col = self.spec.datetime_column.name
-        data[date_col] = pd.to_datetime(data[date_col], unit='s')
+        data[date_col] = pd.to_datetime(data[date_col], unit="s")
         data = data.set_index(date_col)
         # Get the index of the current series id
         series_index = self.target_columns.index(self.series_id)
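For context, here is a minimal, self-contained sketch of the pattern this diff introduces: one auto_arima fit per target series, dispatched through joblib with require="sharedmem" so all workers write into shared result dicts, and per-series failures collected instead of raised. The toy series, the 14-step horizon, and the 0.95 interval width are illustrative assumptions, not values read from the operator config.

```python
# Minimal sketch (assumed toy data) of per-series ARIMA training with shared-memory joblib workers.
import numpy as np
import pandas as pd
import pmdarima as pm
from joblib import Parallel, delayed

idx = pd.date_range("2023-01-01", periods=100, freq="D")
series = {
    "sales_A": pd.Series(np.random.randn(100).cumsum() + 50, index=idx),
    "sales_B": pd.Series(np.random.randn(100).cumsum() + 20, index=idx),
}

models, forecasts, errors = {}, {}, {}
alpha = 1 - 0.95  # pmdarima's alpha is the significance level of the prediction interval


def train_one(target, y, horizon=14):
    try:
        # Stepwise order search; candidate orders that fail to fit are skipped, not raised.
        model = pm.auto_arima(y=y, error_action="ignore")
        yhat, conf_int = model.predict(n_periods=horizon, return_conf_int=True, alpha=alpha)
        models[target] = model
        forecasts[target] = (yhat, conf_int)
    except Exception as e:
        errors[target] = str(e)


# require="sharedmem" selects the threading backend, so the dicts above are shared across workers.
Parallel(n_jobs=-1, require="sharedmem")(
    delayed(train_one)(target, y) for target, y in series.items()
)
print(sorted(forecasts), errors)
```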