@@ -105,6 +105,7 @@ def generate_report(self):
105105 self .datasets ,
106106 self .forecast_output ,
107107 self .spec .datetime_column .name ,
108+ self .original_target_column ,
108109 target_col = self .forecast_col_name ,
109110 )
110111 else :
@@ -125,6 +126,7 @@ def generate_report(self):
125126 target_columns = self .target_columns ,
126127 test_filename = self .spec .test_data .url ,
127128 output = self .forecast_output ,
129+ original_target_column = self .original_target_column ,
128130 target_col = self .forecast_col_name ,
129131 elapsed_time = elapsed_time ,
130132 )
@@ -144,12 +146,13 @@ def generate_report(self):
144146
145147 title_text = dp .Text ("# Forecast Report" )
146148
147- md_columns = " * " .join ([f"{ x } \n " for x in self .target_columns ])
149+ md_columns = " * " .join ([f"{ utils .convert_target (x ,self .original_target_column )} \n "
150+ for x in self .target_columns ])
148151 first_10_rows_blocks = [
149152 dp .DataTable (
150153 df .head (10 ).rename ({col : self .spec .target_column }, axis = 1 ),
151154 caption = "Start" ,
152- label = col ,
155+ label = utils . convert_target ( col , self . original_target_column ) ,
153156 )
154157 for col , df in self .full_data_dict .items ()
155158 ]
@@ -158,7 +161,7 @@ def generate_report(self):
158161 dp .DataTable (
159162 df .tail (10 ).rename ({col : self .spec .target_column }, axis = 1 ),
160163 caption = "End" ,
161- label = col ,
164+ label = utils . convert_target ( col , self . original_target_column ) ,
162165 )
163166 for col , df in self .full_data_dict .items ()
164167 ]
@@ -167,7 +170,7 @@ def generate_report(self):
167170 dp .DataTable (
168171 df .rename ({col : self .spec .target_column }, axis = 1 ).describe (),
169172 caption = "Summary Statistics" ,
170- label = col ,
173+ label = utils . convert_target ( col , self . original_target_column ) ,
171174 )
172175 for col , df in self .full_data_dict .items ()
173176 ]
@@ -254,6 +257,7 @@ def generate_report(self):
254257 forecast_sec = utils .get_forecast_plots (
255258 self .forecast_output ,
256259 self .target_columns ,
260+ self .original_target_column ,
257261 horizon = self .spec .horizon ,
258262 test_data = test_data ,
259263 ci_interval_width = self .spec .confidence_interval_width ,
@@ -280,7 +284,7 @@ def generate_report(self):
280284 )
281285
282286 def _test_evaluate_metrics (
283- self , target_columns , test_filename , output , target_col = "yhat" , elapsed_time = 0
287+ self , target_columns , test_filename , output , original_target_column , target_col = "yhat" , elapsed_time = 0
284288 ):
285289 total_metrics = pd .DataFrame ()
286290 summary_metrics = pd .DataFrame ()
@@ -335,7 +339,7 @@ def _test_evaluate_metrics(
335339 metrics_df = utils ._build_metrics_df (
336340 y_true = y_true [- self .spec .horizon :],
337341 y_pred = y_pred [- self .spec .horizon :],
338- column_name = target_column_i ,
342+ column_name = utils . convert_target ( target_column_i , original_target_column ) ,
339343 )
340344 total_metrics = pd .concat ([total_metrics , metrics_df ], axis = 1 )
341345 else :
@@ -675,7 +679,7 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict:
675679 f"No explanations generated. Ensure that additional data has been provided."
676680 )
677681 else :
678- self .global_explanation [series_id ] = dict (
682+ self .global_explanation [utils . convert_target ( series_id , self . original_target_column ) ] = dict (
679683 zip (
680684 data_trimmed .columns [1 :],
681685 np .average (np .absolute (kernel_explnr_vals [:, 1 :]), axis = 0 ),
@@ -724,4 +728,4 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
724728 ["series_id" , self .spec .target_column ], axis = 1 , inplace = True
725729 )
726730
727- self .local_explanation [series_id ] = local_kernel_explnr_df
731+ self .local_explanation [utils . convert_target ( series_id , self . original_target_column ) ] = local_kernel_explnr_df
0 commit comments