@@ -303,7 +303,7 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306- def get_plot_data_bayesian (self , hdi_prob = 0.94 ) -> pd .DataFrame :
306+ def get_plot_data_bayesian (self , hdi_prob : float = 0.94 ) -> pd .DataFrame :
307307 """
308308 Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309309 """
@@ -321,17 +321,17 @@ def get_plot_data_bayesian(self, hdi_prob=0.94) -> pd.DataFrame:
321321 .mean ("sample" )
322322 .values
323323 )
324- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
325- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
324+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob ). set_index ( pre_data . index )
325+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob ). set_index ( post_data . index )
326326
327327 pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
328328 post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
329- pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
330- post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
329+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob ). set_index ( pre_data . index )
330+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob ). set_index ( post_data . index )
331331
332- self .data_plot = pd .concat ([pre_data , post_data ])
332+ self .plot_data = pd .concat ([pre_data , post_data ])
333333
334- return self .data_plot
334+ return self .plot_data
335335 else :
336336 raise ValueError ("Unsupported model type" )
337337
@@ -345,9 +345,9 @@ def get_plot_data_ols(self) -> pd.DataFrame:
345345 post_data ["prediction" ] = self .post_pred
346346 pre_data ["impact" ] = self .pre_impact
347347 post_data ["impact" ] = self .post_impact
348- self .data_plot = pd .concat ([pre_data , post_data ])
348+ self .plot_data = pd .concat ([pre_data , post_data ])
349349
350- return self .data_plot
350+ return self .plot_data
351351
352352
353353class InterruptedTimeSeries (PrePostFit ):
0 commit comments