@@ -303,14 +303,14 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306- def get_plot_data_bayesian (self ) -> pd .DataFrame :
306+ def get_plot_data_bayesian (self , hdi_prob = 0.94 ) -> pd .DataFrame :
307307 """
308308 Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309309 """
310310 if isinstance (self .model , PyMCModel ):
311311 pre_data = self .datapre .copy ()
312312 post_data = self .datapost .copy ()
313- # PREDICTIONS
313+
314314 pre_data ["prediction" ] = (
315315 az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
316316 .mean ("sample" )
@@ -321,15 +321,13 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
321321 .mean ("sample" )
322322 .values
323323 )
324- # HDI
325- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu )
326- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu )
327- # IMPACT
324+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
325+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
326+
328327 pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
329328 post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
330- # HDI IMPACT
331- pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact )
332- post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact )
329+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
330+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
333331
334332 self .data_plot = pd .concat ([pre_data , post_data ])
335333
0 commit comments