@@ -303,18 +303,18 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306- def get_plot_data (self ) -> pd .DataFrame :
307- """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308-
309- Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310- depending on the model type.
311- """
312- if isinstance (self .model , PyMCModel ):
313- return self .get_plot_data_bayesian ()
314- elif isinstance (self .model , RegressorMixin ):
315- return self .get_plot_data_ols ()
316- else :
317- raise ValueError ("Unsupported model type" )
306+ # def get_plot_data(self) -> pd.DataFrame:
307+ # """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308+
309+ # Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310+ # depending on the model type.
311+ # """
312+ # if isinstance(self.model, PyMCModel):
313+ # return self.get_plot_data_bayesian()
314+ # elif isinstance(self.model, RegressorMixin):
315+ # return self.get_plot_data_ols()
316+ # else:
317+ # raise ValueError("Unsupported model type")
318318
319319 def get_plot_data_bayesian (self ) -> pd .DataFrame :
320320 """
@@ -323,29 +323,42 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
323323 if isinstance (self .model , PyMCModel ):
324324 pre_data = self .datapre .copy ()
325325 post_data = self .datapost .copy ()
326+ # PREDICTIONS
326327 pre_data ["prediction" ] = (
327- az .extract (
328- self .pre_pred , group = "posterior_predictive" , var_names = "mu"
329- )
328+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
330329 .mean ("sample" )
331330 .values
332331 )
333332 post_data ["prediction" ] = (
334- az .extract (
335- self .post_pred , group = "posterior_predictive" , var_names = "mu"
336- )
333+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
337334 .mean ("sample" )
338335 .values
339336 )
337+ # HDI
338+ pre_hdi = (
339+ az .hdi (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
340+ .to_dataframe ()
341+ .unstack (level = "hdi" )
342+ .droplevel (0 , axis = 1 )
343+ )
344+ post_hdi = (
345+ az .hdi (self .post_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
346+ .to_dataframe ()
347+ .unstack (level = "hdi" )
348+ .droplevel (0 , axis = 1 )
349+ )
350+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = pre_hdi
351+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = post_hdi
352+ # IMPACT
340353 pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
341354 post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
342-
355+
343356 self .data_plot = pd .concat ([pre_data , post_data ])
344357
345358 return self .data_plot
346359 else :
347360 raise ValueError ("Unsupported model type" )
348-
361+
349362 def get_plot_data_ols (self ) -> pd .DataFrame :
350363 """
351364 Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
0 commit comments