2323from patsy import build_design_matrices , dmatrices
2424from sklearn .linear_model import LinearRegression as sk_lin_reg
2525
26- from causalpy .custom_exceptions import BadIndexException # NOQA
27- from causalpy .custom_exceptions import DataException , FormulaException
26+ from causalpy .custom_exceptions import (
27+ BadIndexException , # NOQA
28+ DataException ,
29+ FormulaException ,
30+ )
2831from causalpy .plot_utils import plot_xY
29- from causalpy .utils import _is_variable_dummy_coded
32+ from causalpy .utils import _is_variable_dummy_coded , round_num
3033
3134LEGEND_FONT_SIZE = 12
3235az .style .use ("arviz-darkgrid" )
@@ -228,7 +231,7 @@ def _input_validation(self, data, treatment_time):
228231 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
229232 )
230233
231- def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
234+ def plot (self , counterfactual_label = "Counterfactual" , round_to = None , ** kwargs ):
232235 """
233236 Plot the results
234237 """
@@ -275,8 +278,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
275278
276279 ax [0 ].set (
277280 title = f"""
278- Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f }
279- (std = { self .score .r2_std :.3f } )
281+ Pre-intervention Bayesian $R^2$: { round_num ( self .score .r2 , round_to ) }
282+ (std = { round_num ( self .score .r2_std , round_to ) } )
280283 """
281284 )
282285
@@ -580,7 +583,7 @@ def _input_validation(self):
580583 coded. Consisting of 0's and 1's only."""
581584 )
582585
583- def plot (self ):
586+ def plot (self , round_to = None ):
584587 """Plot the results.
585588 Creating the combined mean + HDI legend entries is a bit involved.
586589 """
@@ -658,7 +661,7 @@ def plot(self):
658661 # formatting
659662 ax .set (
660663 xticks = self .x_pred_treatment [self .time_variable_name ].values ,
661- title = self ._causal_impact_summary_stat (),
664+ title = self ._causal_impact_summary_stat (round_to ),
662665 )
663666 ax .legend (
664667 handles = (h_tuple for h_tuple in handles ),
@@ -711,11 +714,14 @@ def _plot_causal_impact_arrow(self, ax):
711714 va = "center" ,
712715 )
713716
714- def _causal_impact_summary_stat (self ) -> str :
717+ def _causal_impact_summary_stat (self , round_to = None ) -> str :
715718 """Computes the mean and 94% credible interval bounds for the causal impact."""
716719 percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
717- ci = "$CI_{94\\ %}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
718- causal_impact = f"{ self .causal_impact .mean ():.2f} , "
720+ ci = (
721+ "$CI_{94\\ %}$"
722+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
723+ )
724+ causal_impact = f"{ round_num (self .causal_impact .mean (), round_to )} , "
719725 return f"Causal impact = { causal_impact + ci } "
720726
721727 def summary (self ) -> None :
@@ -893,7 +899,7 @@ def _is_treated(self, x):
893899 """
894900 return np .greater_equal (x , self .treatment_threshold )
895901
896- def plot (self ):
902+ def plot (self , round_to = None ):
897903 """
898904 Plot the results
899905 """
@@ -918,12 +924,15 @@ def plot(self):
918924 labels = ["Posterior mean" ]
919925
920926 # create strings to compose title
921- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
927+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
922928 r2 = f"Bayesian $R^2$ on all data = { title_info } "
923929 percentiles = self .discontinuity_at_threshold .quantile ([0.03 , 1 - 0.03 ]).values
924- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
930+ ci = (
931+ r"$CI_{94\%}$"
932+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
933+ )
925934 discon = f"""
926- Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f } ,
935+ Discontinuity at threshold = { round_num ( self .discontinuity_at_threshold .mean (), round_to ) } ,
927936 """
928937 ax .set (title = r2 + "\n " + discon + ci )
929938 # Intervention line
@@ -1104,7 +1113,7 @@ def _is_treated(self, x):
11041113 """Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
11051114 return np .greater_equal (x , self .kink_point )
11061115
1107- def plot (self ):
1116+ def plot (self , round_to = None ):
11081117 """
11091118 Plot the results
11101119 """
@@ -1129,12 +1138,15 @@ def plot(self):
11291138 labels = ["Posterior mean" ]
11301139
11311140 # create strings to compose title
1132- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
1141+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
11331142 r2 = f"Bayesian $R^2$ on all data = { title_info } "
11341143 percentiles = self .gradient_change .quantile ([0.03 , 1 - 0.03 ]).values
1135- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1144+ ci = (
1145+ r"$CI_{94\%}$"
1146+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1147+ )
11361148 grad_change = f"""
1137- Change in gradient = { self .gradient_change .mean ():.2f } ,
1149+ Change in gradient = { round_num ( self .gradient_change .mean (), round_to ) } ,
11381150 """
11391151 ax .set (title = r2 + "\n " + grad_change + ci )
11401152 # Intervention line
@@ -1292,7 +1304,7 @@ def _input_validation(self) -> None:
12921304 """
12931305 )
12941306
1295- def plot (self ):
1307+ def plot (self , round_to = None ):
12961308 """Plot the results"""
12971309 fig , ax = plt .subplots (
12981310 2 , 1 , figsize = (7 , 9 ), gridspec_kw = {"height_ratios" : [3 , 1 ]}
@@ -1339,18 +1351,21 @@ def plot(self):
13391351 )
13401352
13411353 # Plot estimated caual impact / treatment effect
1342- az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
1354+ az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ], round_to = round_to )
13431355 ax [1 ].set (title = "Estimated treatment effect" )
13441356 return fig , ax
13451357
1346- def _causal_impact_summary_stat (self ) -> str :
1358+ def _causal_impact_summary_stat (self , round_to ) -> str :
13471359 """Computes the mean and 94% credible interval bounds for the causal impact."""
13481360 percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
1349- ci = r"$CI_{94%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1361+ ci = (
1362+ r"$CI_{94%}$"
1363+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1364+ )
13501365 causal_impact = f"{ self .causal_impact .mean ():.2f} , "
13511366 return f"Causal impact = { causal_impact + ci } "
13521367
1353- def summary (self ) -> None :
1368+ def summary (self , round_to = None ) -> None :
13541369 """
13551370 Print text output summarising the results
13561371 """
@@ -1359,7 +1374,7 @@ def summary(self) -> None:
13591374 print (f"Formula: { self .formula } " )
13601375 print ("\n Results:" )
13611376 # TODO: extra experiment specific outputs here
1362- print (self ._causal_impact_summary_stat ())
1377+ print (self ._causal_impact_summary_stat (round_to ))
13631378 self .print_coefficients ()
13641379
13651380 def _get_treatment_effect_coeff (self ) -> str :
0 commit comments