@@ -247,17 +247,19 @@ def _generate_report(self):
247247 self .explain_model ()
248248
249249 global_explanation_section = None
250- if self .spec .explanations_accuracy_mode != SpeedAccuracyMode .AUTOMLX :
251- # Convert the global explanation data to a DataFrame
252- global_explanation_df = pd .DataFrame (self .global_explanation )
253250
254- self .formatted_global_explanation = (
255- global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
256- )
257- self .formatted_global_explanation = self .formatted_global_explanation .rename (
251+ # Convert the global explanation data to a DataFrame
252+ global_explanation_df = pd .DataFrame (self .global_explanation )
253+
254+ self .formatted_global_explanation = (
255+ global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
256+ )
257+ self .formatted_global_explanation = (
258+ self .formatted_global_explanation .rename (
258259 {self .spec .datetime_column .name : ForecastOutputColumns .DATE },
259260 axis = 1 ,
260261 )
262+ )
261263
262264 aggregate_local_explanations = pd .DataFrame ()
263265 for s_id , local_ex_df in self .local_explanation .items ():
@@ -269,11 +271,15 @@ def _generate_report(self):
269271 self .formatted_local_explanation = aggregate_local_explanations
270272
271273 if not self .target_cat_col :
272- self .formatted_global_explanation = self .formatted_global_explanation .rename (
273- {"Series 1" : self .original_target_column },
274- axis = 1 ,
274+ self .formatted_global_explanation = (
275+ self .formatted_global_explanation .rename (
276+ {"Series 1" : self .original_target_column },
277+ axis = 1 ,
278+ )
279+ )
280+ self .formatted_local_explanation .drop (
281+ "Series" , axis = 1 , inplace = True
275282 )
276- self .formatted_local_explanation .drop ("Series" , axis = 1 , inplace = True )
277283
278284 # Create a markdown section for the global explainability
279285 global_explanation_section = rc .Block (
@@ -422,7 +428,9 @@ def explain_model(self):
422428 # Use the MLExplainer class from AutoMLx to generate explanations
423429 explainer = automlx .MLExplainer (
424430 self .models [s_id ],
425- self .datasets .additional_data .get_data_for_series (series_id = s_id )
431+ self .datasets .additional_data .get_data_for_series (
432+ series_id = s_id
433+ )
426434 .drop (self .spec .datetime_column .name , axis = 1 )
427435 .head (- self .spec .horizon )
428436 if self .spec .additional_data
@@ -433,7 +441,9 @@ def explain_model(self):
433441
434442 # Generate explanations for the forecast
435443 explanations = explainer .explain_prediction (
436- X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
444+ X = self .datasets .additional_data .get_data_for_series (
445+ series_id = s_id
446+ )
437447 .drop (self .spec .datetime_column .name , axis = 1 )
438448 .tail (self .spec .horizon )
439449 if self .spec .additional_data
@@ -445,7 +455,9 @@ def explain_model(self):
445455 explanations_df = pd .concat (
446456 [exp .to_dataframe () for exp in explanations ]
447457 )
448- explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
458+ explanations_df ["row" ] = explanations_df .groupby (
459+ "Feature"
460+ ).cumcount ()
449461 explanations_df = explanations_df .pivot (
450462 index = "row" , columns = "Feature" , values = "Attribution"
451463 )
@@ -454,14 +466,17 @@ def explain_model(self):
454466 # Store the explanations in the local_explanation dictionary
455467 self .local_explanation [s_id ] = explanations_df
456468
457- self .global_explanation [s_id ] = dict (zip (
458- data_i .columns [1 :],
459- np .average (np .absolute (explanations_df [:, 1 :]), axis = 0 ),
469+ self .global_explanation [s_id ] = dict (
470+ zip (
471+ self .local_explanation [s_id ].columns ,
472+ np .nanmean ((self .local_explanation [s_id ]), axis = 0 ),
460473 )
461474 )
462475 else :
463476 # Fall back to the default explanation generation method
464477 super ().explain_model ()
465478 except Exception as e :
466- logger .warning (f"Failed to generate explanations for series { s_id } with error: { e } ." )
479+ logger .warning (
480+ f"Failed to generate explanations for series { s_id } with error: { e } ."
481+ )
467482 logger .debug (f"Full Traceback: { traceback .format_exc ()} " )
0 commit comments