3232from ads .opctl .operator .cmd import run
3333import os
3434import json
35+ import math
3536
3637NUM_ROWS = 1000
3738NUM_SERIES = 10
@@ -191,6 +192,7 @@ def populate_yaml(
191192 additional_data_path = None ,
192193 test_data_path = None ,
193194 output_data_path = None ,
195+ preprocessing = None ,
194196):
195197 if historical_data_path is None :
196198 historical_data_path , additional_data_path , test_data_path = setup_rossman ()
@@ -210,7 +212,8 @@ def populate_yaml(
210212 yaml_i ["spec" ]["datetime_column" ]["name" ] = "Date"
211213 yaml_i ["spec" ]["target_category_columns" ] = ["Store" ]
212214 yaml_i ["spec" ]["horizon" ] = HORIZON
213-
215+ if preprocessing :
216+ yaml_i ["spec" ]["preprocessing" ] = preprocessing
214217 if generate_train_metrics :
215218 yaml_i ["spec" ]["generate_metrics" ] = generate_train_metrics
216219 if model == "autots" :
@@ -393,6 +396,7 @@ def test_0_series(operator_setup, model):
393396 historical_data_path = historical_data_path ,
394397 additional_data_path = additional_data_path ,
395398 test_data_path = test_data_path ,
399+ preprocessing = {"enabled" : False }
396400 )
397401 with pytest .raises (DataMismatchError ):
398402 run_yaml (
@@ -450,6 +454,49 @@ def test_invalid_dates(operator_setup, model):
450454 )
451455
452456
def test_disabling_outlier_treatment(operator_setup):
    """Check that switching off ``outlier_treatment`` keeps outliers in the output.

    Two outlier values are planted in the ``Sales`` column of a historical
    frame. The operator is then run twice with the ``arima`` model:

    1. with the spec as returned by ``populate_yaml`` (no ``preprocessing``
       section), after which neither planted value may appear in the
       ``input_value`` column of ``results/forecast.csv``;
    2. with a ``preprocessing`` section whose ``outlier_treatment`` step is
       ``False``, after which both planted values must appear in that column.
    """
    tmpdirname = operator_setup
    # Lower-case local so the module-level NUM_ROWS constant is not shadowed.
    num_rows = 100
    hist_data_0 = pd.concat(
        [
            HISTORICAL_DATETIME_COL[: num_rows - HORIZON],
            TARGET_COL[: num_rows - HORIZON],
        ],
        axis=1,
    )

    # Row label -> outlier value written into the 'Sales' column.
    planted = {40: 1000, 75: -800}
    outliers = list(planted.values())
    for row_label, value in planted.items():
        hist_data_0.at[row_label, "Sales"] = value

    # Only the historical path is needed; the other two returned paths are unused.
    historical_data_path, _, _ = setup_artificial_data(tmpdirname, hist_data_0)

    yaml_i, output_data_path = populate_yaml(
        tmpdirname=tmpdirname,
        model="arima",
        historical_data_path=historical_data_path,
    )
    # NOTE(review): presumably removed because this frame has no category
    # column and no additional data file was passed in -- confirm against
    # populate_yaml.
    yaml_i["spec"].pop("target_category_columns")
    yaml_i["spec"].pop("additional_data")

    def forecast_input_values():
        """Return the set of ``input_value`` entries from the latest forecast file."""
        # The path must not contain whitespace around the separator; a stray
        # space would point at a non-existent "<tmpdirname> /results" directory.
        forecast = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
        return set(forecast["input_value"])

    # running default pipeline where outlier will be treated
    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
    input_vals_without_outlier = forecast_input_values()
    assert all(
        item not in input_vals_without_outlier for item in outliers
    ), "forecast file should not contain any outliers"

    # switching off outlier_treatment
    preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
    yaml_i["spec"]["preprocessing"] = {"enabled": True, "steps": preprocessing_steps}
    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
    input_vals_with_outlier = forecast_input_values()
    assert all(
        item in input_vals_with_outlier for item in outliers
    ), "forecast file should contain all the outliers"
498+
499+
453500@pytest .mark .parametrize ("model" , MODELS )
454501def test_2_series (operator_setup , model ):
455502 # Test w and w/o add data
@@ -475,12 +522,14 @@ def split_df(df):
475522 historical_data_path , additional_data_path , test_data_path = setup_artificial_data (
476523 tmpdirname , hist_data , add_data , test_data
477524 )
525+ preprocessing_steps = {"missing_value_imputation" : True , "outlier_treatment" : False }
478526 yaml_i , output_data_path = populate_yaml (
479527 tmpdirname = tmpdirname ,
480528 model = model ,
481529 historical_data_path = historical_data_path ,
482530 additional_data_path = additional_data_path ,
483531 test_data_path = test_data_path ,
532+ preprocessing = {"enabled" : True , "steps" : preprocessing_steps }
484533 )
485534 with pytest .raises (DataMismatchError ):
486535 # 4 columns in historical data, but only 1 cat col specified
0 commit comments