8989 parameters_short .append ((model , dataset_i ))
9090
9191
92+ def verify_explanations (global_fn , local_fn , yaml_i , additional_cols ):
93+ glb_expl = pd .read_csv (global_fn , index_col = 0 )
94+ loc_expl = pd .read_csv (local_fn )
95+ assert loc_expl .shape [0 ] == PERIODS
96+ for x in [yaml_i ["spec" ]["datetime_column" ]["name" ], "Series" ]:
97+ assert x in set (loc_expl .columns )
98+ for x in additional_cols :
99+ assert x in set (loc_expl .columns )
100+ assert x in set (glb_expl .index )
101+ assert "Series 1" in set (glb_expl .columns )
102+
103+
92104@pytest .mark .parametrize ("model, dataset_name" , parameters_short )
93105def test_load_datasets (model , dataset_name ):
94106 if model == "automlx" and dataset_name == "WeatherDataset" :
@@ -97,6 +109,7 @@ def test_load_datasets(model, dataset_name):
97109 datetime_col = dataset_i .time_index .name
98110
99111 columns = dataset_i .components
112+ additional_cols = []
100113 target = dataset_i [columns [0 ]][:- PERIODS ]
101114 test = dataset_i [columns [0 ]][- PERIODS :]
102115
@@ -145,7 +158,7 @@ def test_load_datasets(model, dataset_name):
145158 yaml_i ["spec" ]["target_column" ] = columns [0 ]
146159 yaml_i ["spec" ]["datetime_column" ]["name" ] = datetime_col
147160 yaml_i ["spec" ]["horizon" ] = PERIODS
148- if yaml_i ["spec" ].get ("additional_data" ) is not None and model != "automlx " :
161+ if yaml_i ["spec" ].get ("additional_data" ) is not None and model != "autots " :
149162 yaml_i ["spec" ]["generate_explanations" ] = True
150163 if generate_train_metrics :
151164 yaml_i ["spec" ]["generate_metrics" ] = generate_train_metrics
@@ -164,11 +177,13 @@ def test_load_datasets(model, dataset_name):
164177 # sleep(0.1)
165178 run (yaml_i , backend = "operator.local" , debug = False )
166179 subprocess .run (f"ls -a { output_data_path } " , shell = True )
167- if yaml_i ["spec" ]["generate_explanations" ] and model != "autots" :
168- glb_expl = pd .read_csv (f"{ tmpdirname } /results/global_explanation.csv" )
169- print (glb_expl )
170- loc_expl = pd .read_csv (f"{ tmpdirname } /results/local_explanation.csv" )
171- print (loc_expl )
180+ if yaml_i ["spec" ]["generate_explanations" ]:
181+ verify_explanations (
182+ global_fn = f"{ tmpdirname } /results/global_explanation.csv" ,
183+ local_fn = f"{ tmpdirname } /results/local_explanation.csv" ,
184+ yaml_i = yaml_i ,
185+ additional_cols = additional_cols ,
186+ )
172187
173188 test_metrics = pd .read_csv (f"{ tmpdirname } /results/test_metrics.csv" )
174189 print (test_metrics )
0 commit comments