@@ -249,17 +249,18 @@ def _generate_report(self):
249249 self .explain_model ()
250250
251251 global_explanation_section = None
252- if self .spec .explanations_accuracy_mode != SpeedAccuracyMode .AUTOMLX :
253- # Convert the global explanation data to a DataFrame
254- global_explanation_df = pd .DataFrame (self .global_explanation )
255252
256- self .formatted_global_explanation = (
257- global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
258- )
259- self .formatted_global_explanation = self .formatted_global_explanation .rename (
260- {self .spec .datetime_column .name : ForecastOutputColumns .DATE },
261- axis = 1 ,
262- )
253+ # Convert the global explanation data to a DataFrame
254+ global_explanation_df = pd .DataFrame (self .global_explanation )
255+
256+ self .formatted_global_explanation = (
257+ global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
258+ )
259+
260+ self .formatted_global_explanation .rename (
261+ columns = {self .spec .datetime_column .name : ForecastOutputColumns .DATE },
262+ inplace = True ,
263+ )
263264
264265 aggregate_local_explanations = pd .DataFrame ()
265266 for s_id , local_ex_df in self .local_explanation .items ():
@@ -428,7 +429,9 @@ def explain_model(self):
428429 # Use the MLExplainer class from AutoMLx to generate explanations
429430 explainer = automlx .MLExplainer (
430431 self .models [s_id ]["model" ],
431- self .datasets .additional_data .get_data_for_series (series_id = s_id )
432+ self .datasets .additional_data .get_data_for_series (
433+ series_id = s_id
434+ )
432435 .drop (self .spec .datetime_column .name , axis = 1 )
433436 .head (- self .spec .horizon )
434437 if self .spec .additional_data
@@ -463,6 +466,13 @@ def explain_model(self):
463466
464467 # Store the explanations in the local_explanation dictionary
465468 self .local_explanation [s_id ] = explanations_df
469+
470+ self .global_explanation [s_id ] = dict (
471+ zip (
472+ self .local_explanation [s_id ].columns ,
473+ np .nanmean ((self .local_explanation [s_id ]), axis = 0 ),
474+ )
475+ )
466476 else :
467477 # Fall back to the default explanation generation method
468478 super ().explain_model ()
0 commit comments