2525from sklearn .base import RegressorMixin
2626
2727from causalpy .custom_exceptions import BadIndexException
28- from causalpy .plot_utils import plot_xY
28+ from causalpy .plot_utils import plot_xY , get_hdi_to_df
2929from causalpy .pymc_models import PyMCModel
3030from causalpy .utils import round_num
3131
@@ -303,19 +303,6 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306- # def get_plot_data(self) -> pd.DataFrame:
307- # """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308-
309- # Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310- # depending on the model type.
311- # """
312- # if isinstance(self.model, PyMCModel):
313- # return self.get_plot_data_bayesian()
314- # elif isinstance(self.model, RegressorMixin):
315- # return self.get_plot_data_ols()
316- # else:
317- # raise ValueError("Unsupported model type")
318-
319306 def get_plot_data_bayesian (self ) -> pd .DataFrame :
320307 """
321308 Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
@@ -335,23 +322,14 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
335322 .values
336323 )
337324 # HDI
338- pre_hdi = (
339- az .hdi (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
340- .to_dataframe ()
341- .unstack (level = "hdi" )
342- .droplevel (0 , axis = 1 )
343- )
344- post_hdi = (
345- az .hdi (self .post_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
346- .to_dataframe ()
347- .unstack (level = "hdi" )
348- .droplevel (0 , axis = 1 )
349- )
350- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = pre_hdi
351- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = post_hdi
325+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu )
326+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu )
352327 # IMPACT
353328 pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
354329 post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
330+ # HDI IMPACT
331+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact )
332+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact )
355333
356334 self .data_plot = pd .concat ([pre_data , post_data ])
357335
0 commit comments