@@ -256,6 +256,7 @@ def _generate_report(self):
256256 self .outputs [s_id ], include_legend = True
257257 ),
258258 series_ids = series_ids ,
259+ target_category_column = self .target_cat_col
259260 )
260261 section_1 = rc .Block (
261262 rc .Heading ("Forecast Overview" , level = 2 ),
@@ -268,6 +269,7 @@ def _generate_report(self):
268269 sec2 = _select_plot_list (
269270 lambda s_id : self .models [s_id ].plot_components (self .outputs [s_id ]),
270271 series_ids = series_ids ,
272+ target_category_column = self .target_cat_col
271273 )
272274 section_2 = rc .Block (
273275 rc .Heading ("Forecast Broken Down by Trend Component" , level = 2 ), sec2
@@ -281,7 +283,9 @@ def _generate_report(self):
281283 sec3_figs [s_id ].gca (), self .models [s_id ], self .outputs [s_id ]
282284 )
283285 sec3 = _select_plot_list (
284- lambda s_id : sec3_figs [s_id ], series_ids = series_ids
286+ lambda s_id : sec3_figs [s_id ],
287+ series_ids = series_ids ,
288+ target_category_column = self .target_cat_col
285289 )
286290 section_3 = rc .Block (rc .Heading ("Forecast Changepoints" , level = 2 ), sec3 )
287291
@@ -295,7 +299,7 @@ def _generate_report(self):
295299 pd .Series (
296300 m .seasonalities ,
297301 index = pd .Index (m .seasonalities .keys (), dtype = "object" ),
298- name = s_id ,
302+ name = s_id if self . target_cat_col else self . original_target_column ,
299303 dtype = "object" ,
300304 )
301305 )
@@ -316,15 +320,6 @@ def _generate_report(self):
316320 global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
317321 )
318322
319- # Create a markdown section for the global explainability
320- global_explanation_section = rc .Block (
321- rc .Heading ("Global Explanation of Models" , level = 2 ),
322- rc .Text (
323- "The following tables provide the feature attribution for the global explainability."
324- ),
325- rc .DataTable (self .formatted_global_explanation , index = True ),
326- )
327-
328323 aggregate_local_explanations = pd .DataFrame ()
329324 for s_id , local_ex_df in self .local_explanation .items ():
330325 local_ex_df_copy = local_ex_df .copy ()
@@ -334,17 +329,33 @@ def _generate_report(self):
334329 )
335330 self .formatted_local_explanation = aggregate_local_explanations
336331
332+ if not self .target_cat_col :
333+ self .formatted_global_explanation = self .formatted_global_explanation .rename (
334+ {"Series 1" : self .original_target_column },
335+ axis = 1 ,
336+ )
337+ self .formatted_local_explanation .drop ("Series" , axis = 1 , inplace = True )
338+
339+ # Create a markdown section for the global explainability
340+ global_explanation_section = rc .Block (
341+ rc .Heading ("Global Explanation of Models" , level = 2 ),
342+ rc .Text (
343+ "The following tables provide the feature attribution for the global explainability."
344+ ),
345+ rc .DataTable (self .formatted_global_explanation , index = True ),
346+ )
347+
337348 blocks = [
338349 rc .DataTable (
339350 local_ex_df .div (local_ex_df .abs ().sum (axis = 1 ), axis = 0 ) * 100 ,
340- label = s_id ,
351+ label = s_id if self . target_cat_col else None ,
341352 index = True ,
342353 )
343354 for s_id , local_ex_df in self .local_explanation .items ()
344355 ]
345356 local_explanation_section = rc .Block (
346357 rc .Heading ("Local Explanation of Models" , level = 2 ),
347- rc .Select (blocks = blocks ),
358+ rc .Select (blocks = blocks ) if len ( blocks ) > 1 else blocks [ 0 ] ,
348359 )
349360
350361 # Append the global explanation text and section to the "all_sections" list
0 commit comments