2525from sklearn .base import RegressorMixin
2626
2727from causalpy .custom_exceptions import BadIndexException
28- from causalpy .plot_utils import plot_xY
28+ from causalpy .plot_utils import get_hdi_to_df , plot_xY
2929from causalpy .pymc_models import PyMCModel
3030from causalpy .utils import round_num
3131
@@ -123,7 +123,7 @@ def summary(self, round_to=None) -> None:
123123 print (f"Formula: { self .formula } " )
124124 self .print_coefficients (round_to )
125125
126- def bayesian_plot (
126+ def _bayesian_plot (
127127 self , round_to = None , ** kwargs
128128 ) -> tuple [plt .Figure , List [plt .Axes ]]:
129129 """
@@ -231,7 +231,7 @@ def bayesian_plot(
231231
232232 return fig , ax
233233
234- def ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
234+ def _ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
235235 """
236236 Plot the results
237237
@@ -303,6 +303,70 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306+ def get_plot_data_bayesian (self , hdi_prob : float = 0.94 ) -> pd .DataFrame :
307+ """
308+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309+
310+ :param hdi_prob:
311+ Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
312+ """
313+ if isinstance (self .model , PyMCModel ):
314+ hdi_pct = int (round (hdi_prob * 100 ))
315+
316+ pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
317+ pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
318+ impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
319+ impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
320+
321+ pre_data = self .datapre .copy ()
322+ post_data = self .datapost .copy ()
323+
324+ pre_data ["prediction" ] = (
325+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
326+ .mean ("sample" )
327+ .values
328+ )
329+ post_data ["prediction" ] = (
330+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
331+ .mean ("sample" )
332+ .values
333+ )
334+ pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
335+ self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
336+ ).set_index (pre_data .index )
337+ post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
338+ self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
339+ ).set_index (post_data .index )
340+
341+ pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
342+ post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
343+ pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
344+ self .pre_impact , hdi_prob = hdi_prob
345+ ).set_index (pre_data .index )
346+ post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
347+ self .post_impact , hdi_prob = hdi_prob
348+ ).set_index (post_data .index )
349+
350+ self .plot_data = pd .concat ([pre_data , post_data ])
351+
352+ return self .plot_data
353+ else :
354+ raise ValueError ("Unsupported model type" )
355+
356+ def get_plot_data_ols (self ) -> pd .DataFrame :
357+ """
358+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
359+ """
360+ pre_data = self .datapre .copy ()
361+ post_data = self .datapost .copy ()
362+ pre_data ["prediction" ] = self .pre_pred
363+ post_data ["prediction" ] = self .post_pred
364+ pre_data ["impact" ] = self .pre_impact
365+ post_data ["impact" ] = self .post_impact
366+ self .plot_data = pd .concat ([pre_data , post_data ])
367+
368+ return self .plot_data
369+
306370
307371class InterruptedTimeSeries (PrePostFit ):
308372 """
@@ -382,7 +446,7 @@ class SyntheticControl(PrePostFit):
382446 supports_ols = True
383447 supports_bayes = True
384448
385- def bayesian_plot (self , * args , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
449+ def _bayesian_plot (self , * args , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
386450 """
387451 Plot the results
388452
@@ -393,7 +457,7 @@ def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
393457 Whether to plot the control units as well. Defaults to False.
394458 """
395459 # call the super class method
396- fig , ax = super ().bayesian_plot (* args , ** kwargs )
460+ fig , ax = super ()._bayesian_plot (* args , ** kwargs )
397461
398462 # additional plotting functionality for the synthetic control experiment
399463 plot_predictors = kwargs .get ("plot_predictors" , False )
0 commit comments