@@ -929,19 +929,24 @@ def test_generate_files(operator_setup, model):
929929 )
930930
931931 yaml_i = TEMPLATE_YAML .copy ()
932- yaml_i ["spec" ]["horizon" ] = 10
932+ yaml_i ["spec" ]["horizon" ] = 3
933933 yaml_i ["spec" ]["model" ] = model
934934 yaml_i ["spec" ]["historical_data" ] = {"format" : "pandas" }
935+ yaml_i ["spec" ]["additional_data" ] = {"format" : "pandas" }
935936 yaml_i ["spec" ]["target_column" ] = TARGET_COL .name
936937 yaml_i ["spec" ]["datetime_column" ]["name" ] = HISTORICAL_DATETIME_COL .name
937- yaml_i ["spec" ]["report_title" ] = "Skibidi ADS Skibidi"
938938 yaml_i ["spec" ]["output_directory" ]["url" ] = operator_setup
939- yaml_i ["spec" ]["generate_explanations_file " ] = False
939+ yaml_i ["spec" ]["generate_explanation_files " ] = False
940940 yaml_i ["spec" ]["generate_forecast_file" ] = False
941941 yaml_i ["spec" ]["generate_metrics_file" ] = False
942+ yaml_i ["spec" ]["generate_explanations" ] = True
942943
943944 df = pd .concat ([HISTORICAL_DATETIME_COL [:15 ], TARGET_COL [:15 ]], axis = 1 )
945+ df_add = pd .concat ([HISTORICAL_DATETIME_COL [:18 ], ADD_COLS [:18 ]], axis = 1 )
946+ print (f"df: { df } " )
947+ print (f"df_add: { df_add } " )
944948 yaml_i ["spec" ]["historical_data" ]["data" ] = df
949+ yaml_i ["spec" ]["additional_data" ]["data" ] = df_add
945950 operator_config = ForecastOperatorConfig .from_dict (yaml_i )
946951 results = operate (operator_config )
947952 files = os .listdir (operator_setup )
@@ -952,8 +957,28 @@ def test_generate_files(operator_setup, model):
952957 assert (
953958 "metrics.csv" not in files
954959 ), "Generated metrics file, but `generate_metrics_file` was set False"
960+ assert (
961+ "local_explanations.csv" not in files
962+ ), "Generated metrics file, but `generate_explanation_files` was set False"
963+ assert (
964+ "global_explanations.csv" not in files
965+ ), "Generated metrics file, but `generate_explanation_files` was set False"
955966 assert not results .get_forecast ().empty
956967 assert not results .get_metrics ().empty
968+ assert not results .get_global_explanations ().empty
969+ assert not results .get_local_explanations ().empty
970+
971+ yaml_i ["spec" ].pop ("generate_explanation_files" )
972+ yaml_i ["spec" ].pop ("generate_forecast_file" )
973+ yaml_i ["spec" ].pop ("generate_metrics_file" )
974+ operator_config = ForecastOperatorConfig .from_dict (yaml_i )
975+ results = operate (operator_config )
976+ files = os .listdir (operator_setup )
977+ assert "report.html" in files , "Failed to generate report"
978+ assert "forecast.csv" in files , "Failed to generate forecast file"
979+ assert "metrics.csv" in files , "Failed to generated metrics file"
980+ assert "local_explanation.csv" in files , "Failed to generated local expl file"
981+ assert "global_explanation.csv" in files , "Failed to generated global expl file"
957982
958983
959984if __name__ == "__main__" :
0 commit comments