1313import ads
1414from ads .opctl .operator .lowcode .common .utils import load_data
1515from ads .opctl .operator .common .operator_config import InputData
16+ from ads .opctl .operator .lowcode .forecast .const import SupportedModels
1617
1718ads .set_auth ("resource_principal" )
1819
2627 Inference script. This script is used for prediction by scoring server when schema is known.
2728"""
2829
29- AUTOTS = "autots"
30- ARIMA = "arima"
31- PROPHET = "prophet"
32- NEURALPROPHET = "neuralprophet"
33- AUTOMLX = "automlx"
34-
3530
3631@lru_cache (maxsize = 10 )
3732def load_model ():
@@ -142,7 +137,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
142137 future_df [date_col_name ] = pd .to_datetime (
143138 future_df [date_col_name ], format = date_col_format
144139 )
145- if model_name == AUTOTS :
140+ if model_name == SupportedModels . AutoTS :
146141 series_id_col = "Series"
147142 full_data_indexed = future_df .rename (columns = {target_cat_col : series_id_col })
148143 additional_regressors = list (
@@ -155,12 +150,12 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
155150 )
156151 pred_obj = model_object .predict (future_regressor = future_reg )
157152 return pred_obj .forecast [series_id ].tolist ()
158- elif model_name == PROPHET and series_id in model_object :
153+ elif model_name == SupportedModels . Prophet and series_id in model_object :
159154 model = model_object [series_id ]
160155 processed = future_df .rename (columns = {date_col_name : 'ds' , target_column : 'y' })
161156 forecast = model .predict (processed )
162157 return forecast ['yhat' ].tolist ()
163- elif model_name == NEURALPROPHET and series_id in model_object :
158+ elif model_name == SupportedModels . NeuralProphet and series_id in model_object :
164159 model = model_object [series_id ]
165160 model .restore_trainer ()
166161 accepted_regressors = list (model .config_regressors .regressors .keys ())
@@ -169,7 +164,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
169164 future ["y" ] = None
170165 forecast = model .predict (future )
171166 return forecast ['yhat1' ].tolist ()
172- elif model_name == ARIMA and series_id in model_object :
167+ elif model_name == SupportedModels . Arima and series_id in model_object :
173168 model = model_object [series_id ]
174169 future_df = future_df .set_index (date_col_name )
175170 x_pred = future_df .drop (target_cat_col , axis = 1 )
@@ -180,7 +175,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
180175 )
181176 yhat_clean = pd .DataFrame (yhat , index = yhat .index , columns = ["yhat" ])
182177 return yhat_clean ['yhat' ].tolist ()
183- elif model_name == AUTOMLX and series_id in model_object :
178+ elif model_name == SupportedModels . AutoMLX and series_id in model_object :
184179 # automlx model
185180 model = model_object [series_id ]
186181 x_pred = future_df .drop (target_cat_col , axis = 1 )
0 commit comments