1111from functools import lru_cache
1212import logging
1313import ads
14- from prophet import Prophet
15- from neuralprophet import NeuralProphet
16- from pmdarima import ARIMA
17- from autots import AutoTS
18- from automlx ._interface .forecaster import AutoForecaster
14+ from ads .opctl .operator .lowcode .common .utils import load_data
15+ from ads .opctl .operator .common .operator_config import InputData
1916
2017ads .set_auth ("resource_principal" )
2118
2926 Inference script. This script is used for prediction by scoring server when schema is known.
3027"""
3128
29+ AUTOTS = "autots"
30+ ARIMA = "arima"
31+ PROPHET = "prophet"
32+ NEURALPROPHET = "neuralprophet"
33+ AUTOMLX = "automlx"
34+
3235
3336@lru_cache (maxsize = 10 )
3437def load_model ():
@@ -132,14 +135,14 @@ def post_inference(yhat):
132135 return yhat
133136
134137
135- def get_forecast (future_df , series_id , model_object , date_col , target_column , target_cat_col , horizon ):
138+ def get_forecast (future_df , model_name , series_id , model_object , date_col , target_column , target_cat_col , horizon ):
136139 date_col_name = date_col ["name" ]
137140 date_col_format = date_col ["format" ]
138141 future_df [target_cat_col ] = future_df [target_cat_col ].astype ("str" )
139142 future_df [date_col_name ] = pd .to_datetime (
140143 future_df [date_col_name ], format = date_col_format
141144 )
142- if isinstance ( model_object , AutoTS ) :
145+ if model_name == AUTOTS :
143146 series_id_col = "Series"
144147 full_data_indexed = future_df .rename (columns = {target_cat_col : series_id_col })
145148 additional_regressors = list (
@@ -152,12 +155,12 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
152155 )
153156 pred_obj = model_object .predict (future_regressor = future_reg )
154157 return pred_obj .forecast [series_id ].tolist ()
155- elif series_id in model_object and isinstance ( model_object [ series_id ], Prophet ) :
158+ elif model_name == PROPHET and series_id in model_object :
156159 model = model_object [series_id ]
157160 processed = future_df .rename (columns = {date_col_name : 'ds' , target_column : 'y' })
158161 forecast = model .predict (processed )
159162 return forecast ['yhat' ].tolist ()
160- elif series_id in model_object and isinstance ( model_object [ series_id ], NeuralProphet ) :
163+ elif model_name == NEURALPROPHET and series_id in model_object :
161164 model = model_object [series_id ]
162165 model .restore_trainer ()
163166 accepted_regressors = list (model .config_regressors .keys ())
@@ -166,7 +169,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
166169 future ["y" ] = None
167170 forecast = model .predict (future )
168171 return forecast ['yhat1' ].tolist ()
169- elif series_id in model_object and isinstance ( model_object [ series_id ], ARIMA ) :
172+ elif model_name == ARIMA and series_id in model_object :
170173 model = model_object [series_id ]
171174 future_df = future_df .set_index (date_col_name )
172175 x_pred = future_df .drop (target_cat_col , axis = 1 )
@@ -177,7 +180,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
177180 )
178181 yhat_clean = pd .DataFrame (yhat , index = yhat .index , columns = ["yhat" ])
179182 return yhat_clean ['yhat' ].tolist ()
180- elif series_id in model_object and isinstance ( model_object [ series_id ], AutoForecaster ) :
183+ elif model_name == AUTOMLX and series_id in model_object :
181184 # automlx model
182185 model = model_object [series_id ]
183186 x_pred = future_df .drop (target_cat_col , axis = 1 )
@@ -188,7 +191,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
188191 )
189192 return forecast [target_column ].tolist ()
190193 else :
191- raise Exception ( f"Invalid model object type: { type (model_object ).__name__ } ." )
194+ raise Exception (f"Invalid model object type: { type (model_object ).__name__ } ." )
192195
193196
194197def predict (data , model = load_model ()) -> dict :
@@ -211,20 +214,26 @@ def predict(data, model=load_model()) -> dict:
211214 models = model ["models" ]
212215 specs = model ["spec" ]
213216 horizon = specs ["horizon" ]
217+ model_name = specs ["model" ]
214218 date_col = specs ["datetime_column" ]
215219 target_column = specs ["target_column" ]
216- forecasts = {}
217- uri = f"{ data ['additional_data_uri' ]} "
218220 target_category_column = specs ["target_category_columns" ][0 ]
219- signer = ads .common .auth .default_signer () if uri .lower ().startswith ("oci://" ) else {}
220- additional_data = pd .read_csv (uri , storage_options = signer )
221+
222+ try :
223+ input_data = InputData (** data ["additional_data" ])
224+ except TypeError as e :
225+ raise ValueError (f"Validation error: { e } " )
226+ additional_data = load_data (input_data )
227+
221228 unique_values = additional_data [target_category_column ].unique ()
229+ forecasts = {}
222230 for key in unique_values :
223231 try :
224232 s_id = str (key )
225233 filtered = additional_data [additional_data [target_category_column ] == key ]
226234 future = filtered .tail (horizon )
227- forecast = get_forecast (future , s_id , models , date_col , target_column , target_category_column , horizon )
235+ forecast = get_forecast (future , model_name , s_id , models , date_col ,
236+ target_column , target_category_column , horizon )
228237 forecasts [s_id ] = json .dumps (forecast )
229238 except Exception as e :
230239 raise RuntimeError (
0 commit comments