@@ -259,6 +259,7 @@ def setup_faulty_rossman():
     additional_data_path = f"{data_folder}/rs_2_add_encoded.csv"
     return historical_data_path, additional_data_path

+
 def setup_small_rossman():
     curr_dir = pathlib.Path(__file__).parent.resolve()
     data_folder = f"{curr_dir}/../data/"
@@ -396,7 +397,7 @@ def test_0_series(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": False}
+        preprocessing={"enabled": False},
     )
     with pytest.raises(DataMismatchError):
         run_yaml(
@@ -465,36 +466,36 @@ def test_disabling_outlier_treatment(operator_setup):
         axis=1,
     )
     outliers = [1000, -800]
-    hist_data_0.at[40, 'Sales'] = outliers[0]
-    hist_data_0.at[75, 'Sales'] = outliers[1]
+    hist_data_0.at[40, "Sales"] = outliers[0]
+    hist_data_0.at[75, "Sales"] = outliers[1]
     historical_data_path, additional_data_path, test_data_path = setup_artificial_data(
         tmpdirname, hist_data_0
     )

     yaml_i, output_data_path = populate_yaml(
-        tmpdirname=tmpdirname,
-        model="arima",
-        historical_data_path=historical_data_path
+        tmpdirname=tmpdirname, model="arima", historical_data_path=historical_data_path
     )
     yaml_i["spec"].pop("target_category_columns")
     yaml_i["spec"].pop("additional_data")

     # running default pipeline where outlier will be treated
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_without_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_without_outlier = set(forecast_without_outlier['input_value'])
+    input_vals_without_outlier = set(forecast_without_outlier["input_value"])
     assert all(
-        item not in input_vals_without_outlier for item in outliers), "forecast file should not contain any outliers"
+        item not in input_vals_without_outlier for item in outliers
+    ), "forecast file should not contain any outliers"

     # switching off outlier_treatment
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     preprocessing = {"enabled": True, "steps": preprocessing_steps}
     yaml_i["spec"]["preprocessing"] = preprocessing
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_with_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_with_outlier = set(forecast_with_outlier['input_value'])
+    input_vals_with_outlier = set(forecast_with_outlier["input_value"])
     assert all(
-        item in input_vals_with_outlier for item in outliers), "forecast file should contain all the outliers"
+        item in input_vals_with_outlier for item in outliers
+    ), "forecast file should contain all the outliers"


 @pytest.mark.parametrize("model", MODELS)
@@ -529,7 +530,7 @@ def split_df(df):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": True, "steps": preprocessing_steps}
+        preprocessing={"enabled": True, "steps": preprocessing_steps},
     )
     with pytest.raises(DataMismatchError):
         # 4 columns in historical data, but only 1 cat col specified
@@ -561,8 +562,8 @@ def test_all_series_failure(model):
     )
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     yaml_i["spec"]["model"] = model
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
     if yaml_i["spec"].get("additional_data") is not None and model != "autots":
         yaml_i["spec"]["generate_explanations"] = True
     if model == "autots":
@@ -571,14 +572,15 @@ def test_all_series_failure(model):
         yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}

     module_to_patch = {
-        "arima": 'pmdarima.auto_arima',
-        "autots": 'autots.AutoTS',
-        "automlx": 'automlx.Pipeline',
-        "prophet": 'prophet.Prophet',
-        "neuralprophet": 'neuralprophet.NeuralProphet'
+        "arima": "pmdarima.auto_arima",
+        "autots": "autots.AutoTS",
+        "automlx": "automlx.Pipeline",
+        "prophet": "prophet.Prophet",
+        "neuralprophet": "neuralprophet.NeuralProphet",
     }
-    with patch(module_to_patch[model], side_effect=Exception("Custom exception message")):
-
+    with patch(
+        module_to_patch[model], side_effect=Exception("Custom exception message")
+    ):
         run(yaml_i, backend="operator.local", debug=False)

     report_path = f"{output_data_path}/report.html"
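For readers unfamiliar with the mocking pattern above: `unittest.mock.patch` with `side_effect` replaces the patched target so that every call raises, which is how the test forces every series to fail and exercises the operator's errors.json handling. A minimal, self-contained sketch of that pattern follows; the `fit_series` helper and the patched `math.sqrt` target are illustrative stand-ins, not the operator's code.

from unittest.mock import patch

def fit_series(series_id):
    # Stand-in for a per-series model fit; in the real test the patched target
    # is the model library entry point (e.g. pmdarima.auto_arima).
    import math
    return math.sqrt(series_id)

errors = {}
with patch("math.sqrt", side_effect=Exception("Custom exception message")):
    for series_id in (1, 13):
        try:
            fit_series(series_id)
        except Exception as exc:
            # Mirrors how per-series failures end up keyed in errors.json.
            errors[str(series_id)] = {"error": str(exc)}

assert "Custom exception message" in errors["1"]["error"]
assert "Custom exception message" in errors["13"]["error"]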
@@ -588,17 +590,26 @@ def test_all_series_failure(model):
     assert os.path.exists(error_path), f"Error file not found at {error_path}"

     # Additionally, you can read the content of the error.json and assert its content
-    with open(error_path, 'r') as error_file:
+    with open(error_path, "r") as error_file:
         error_content = json.load(error_file)
-        assert "Custom exception message" in error_content["1"]["error"], "Error message mismatch"
-        assert "Custom exception message" in error_content["13"]["error"], "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["1"]["error"]
+        ), "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["13"]["error"]
+        ), "Error message mismatch"

     if yaml_i["spec"]["generate_explanations"]:
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"

         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"
+

 @pytest.mark.parametrize("model", MODELS)
 def test_arima_automlx_errors(operator_setup, model):
@@ -611,29 +622,38 @@ def test_arima_automlx_errors(operator_setup, model):
     )

     """
-    Arima was failing for constant trend when there are constant columns and when there are boolean columns.
-    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
+    Arima was failing for constant trend when there are constant columns and when there are boolean columns.
+    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
     This test checks that report, metrics, explanations are generated for this case.
     """

     """
-    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
+    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
     outputs get generated and that error is shown in errors.json
     """

     """
-    explanations generation is failing when boolean columns are passed.
-    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
+    explanations generation is failing when boolean columns are passed.
+    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
     any supported types according to the casting rule ''safe''
     Added label encoding before passing data to explainer
     """
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
-    yaml_i['spec']['generate_explanations'] = True
-    yaml_i['spec']['model'] = model
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
+    yaml_i["spec"]["generate_explanations"] = True
+    yaml_i["spec"]["model"] = model
+    if model == "autots":
+        yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
+    if model == "automlx":
+        yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}

-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )

     report_path = f"{tmpdirname}/results/report.html"
     assert os.path.exists(report_path), f"Report file not found at {report_path}"
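The docstrings in the hunk above describe the preprocessing that motivated this test: boolean exogenous columns are label encoded and constant-valued columns are dropped before fitting ARIMA with a constant trend, and before data reaches the explainer. A minimal pandas sketch of that idea follows; the `add_df` frame and its column names are hypothetical, and this illustrates the technique rather than the operator's actual implementation.

import pandas as pd

add_df = pd.DataFrame(
    {
        "promo": [True, False, True, True],  # boolean column -> label encode
        "store_open": [1, 1, 1, 1],          # constant column -> drop
        "price": [9.5, 9.7, 9.4, 9.6],
    }
)

# Label-encode boolean columns so downstream numeric routines (ARIMA, the
# explainer) receive 0/1 values instead of dtype=bool.
bool_cols = add_df.select_dtypes(include="bool").columns
add_df[bool_cols] = add_df[bool_cols].astype(int)

# Drop columns with a single unique value; with a constant trend they add a
# collinear regressor and can break the fit.
constant_cols = [c for c in add_df.columns if add_df[c].nunique() == 1]
add_df = add_df.drop(columns=constant_cols)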
@@ -642,23 +662,28 @@ def test_arima_automlx_errors(operator_setup, model):
     assert os.path.exists(forecast_path), f"Forecast file not found at {report_path}"
     assert not pd.read_csv(forecast_path).empty

-
     error_path = f"{tmpdirname}/results/errors.json"
     if model == "arima":
         assert not os.path.exists(error_path), f"Error file not found at {error_path}"
     elif model == "automlx":
         assert os.path.exists(error_path), f"Error file not found at {error_path}"
-        with open(error_path, 'r') as error_file:
+        with open(error_path, "r") as error_file:
             error_content = json.load(error_file)
-            assert "Input data does not have a consistent (in terms of diff) DatetimeIndex." in error_content["13"][
-                "error"], "Error message mismatch"
+            assert (
+                "Input data does not have a consistent (in terms of diff) DatetimeIndex."
+                in error_content["13"]["error"]
+            ), "Error message mismatch"

     if model != "autots":
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"

         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"

         glb_expl = pd.read_csv(global_fn, index_col=0)
         loc_expl = pd.read_csv(local_fn)
@@ -680,13 +705,20 @@ def test_date_format(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
     )
-    yaml_i['spec']['horizon'] = 10
+    yaml_i["spec"]["horizon"] = 10
     yaml_i["spec"]["model"] = model
     if model == "autots":
         yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}

-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
-    assert pd.read_csv(additional_data_path)['Date'].equals(pd.read_csv(f"{tmpdirname}/results/forecast.csv")['Date'])
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )
+    assert pd.read_csv(additional_data_path)["Date"].equals(
+        pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+    )


 if __name__ == "__main__":