2323from patsy import build_design_matrices , dmatrices
2424from sklearn .linear_model import LinearRegression as sk_lin_reg
2525
26- from causalpy .custom_exceptions import BadIndexException
27- from causalpy .custom_exceptions import DataException , FormulaException
26+ from causalpy .custom_exceptions import (
27+ BadIndexException , # NOQA
28+ DataException ,
29+ FormulaException ,
30+ )
2831from causalpy .plot_utils import plot_xY
29- from causalpy .utils import _is_variable_dummy_coded
32+ from causalpy .utils import _is_variable_dummy_coded , round_num
3033
3134LEGEND_FONT_SIZE = 12
3235az .style .use ("arviz-darkgrid" )
@@ -228,9 +231,12 @@ def _input_validation(self, data, treatment_time):
228231 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
229232 )
230233
231- def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
234+ def plot (self , counterfactual_label = "Counterfactual" , round_to = None , ** kwargs ):
232235 """
233236 Plot the results
237+
238+ :param round_to:
239+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
234240 """
235241 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
236242
@@ -275,8 +281,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
275281
276282 ax [0 ].set (
277283 title = f"""
278- Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f }
279- (std = { self .score .r2_std :.3f } )
284+ Pre-intervention Bayesian $R^2$: { round_num ( self .score .r2 , round_to ) }
285+ (std = { round_num ( self .score .r2_std , round_to ) } )
280286 """
281287 )
282288
@@ -416,7 +422,11 @@ class SyntheticControl(PrePostFit):
416422 expt_type = "Synthetic Control"
417423
418424 def plot (self , plot_predictors = False , ** kwargs ):
419- """Plot the results"""
425+ """Plot the results
426+
427+ :param round_to:
428+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
429+ """
420430 fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
421431 if plot_predictors :
422432 # plot control units as well
@@ -580,9 +590,11 @@ def _input_validation(self):
580590 coded. Consisting of 0's and 1's only."""
581591 )
582592
583- def plot (self ):
593+ def plot (self , round_to = None ):
584594 """Plot the results.
585- Creating the combined mean + HDI legend entries is a bit involved.
595+
596+ :param round_to:
597+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
586598 """
587599 fig , ax = plt .subplots ()
588600
@@ -658,7 +670,7 @@ def plot(self):
658670 # formatting
659671 ax .set (
660672 xticks = self .x_pred_treatment [self .time_variable_name ].values ,
661- title = self ._causal_impact_summary_stat (),
673+ title = self ._causal_impact_summary_stat (round_to ),
662674 )
663675 ax .legend (
664676 handles = (h_tuple for h_tuple in handles ),
@@ -711,11 +723,14 @@ def _plot_causal_impact_arrow(self, ax):
711723 va = "center" ,
712724 )
713725
714- def _causal_impact_summary_stat (self ) -> str :
726+ def _causal_impact_summary_stat (self , round_to = None ) -> str :
715727 """Computes the mean and 94% credible interval bounds for the causal impact."""
716728 percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
717- ci = "$CI_{94\\ %}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
718- causal_impact = f"{ self .causal_impact .mean ():.2f} , "
729+ ci = (
730+ "$CI_{94\\ %}$"
731+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
732+ )
733+ causal_impact = f"{ round_num (self .causal_impact .mean (), round_to )} , "
719734 return f"Causal impact = { causal_impact + ci } "
720735
721736 def summary (self ) -> None :
@@ -893,9 +908,12 @@ def _is_treated(self, x):
893908 """
894909 return np .greater_equal (x , self .treatment_threshold )
895910
896- def plot (self ):
911+ def plot (self , round_to = None ):
897912 """
898913 Plot the results
914+
915+ :param round_to:
916+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
899917 """
900918 fig , ax = plt .subplots ()
901919 # Plot raw data
@@ -918,12 +936,15 @@ def plot(self):
918936 labels = ["Posterior mean" ]
919937
920938 # create strings to compose title
921- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
939+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
922940 r2 = f"Bayesian $R^2$ on all data = { title_info } "
923941 percentiles = self .discontinuity_at_threshold .quantile ([0.03 , 1 - 0.03 ]).values
924- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
942+ ci = (
943+ r"$CI_{94\%}$"
944+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
945+ )
925946 discon = f"""
926- Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f } ,
947+ Discontinuity at threshold = { round_num ( self .discontinuity_at_threshold .mean (), round_to ) } ,
927948 """
928949 ax .set (title = r2 + "\n " + discon + ci )
929950 # Intervention line
@@ -1104,9 +1125,12 @@ def _is_treated(self, x):
11041125 """Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
11051126 return np .greater_equal (x , self .kink_point )
11061127
1107- def plot (self ):
1128+ def plot (self , round_to = None ):
11081129 """
11091130 Plot the results
1131+
1132+ :param round_to:
1133+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
11101134 """
11111135 fig , ax = plt .subplots ()
11121136 # Plot raw data
@@ -1129,12 +1153,15 @@ def plot(self):
11291153 labels = ["Posterior mean" ]
11301154
11311155 # create strings to compose title
1132- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
1156+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
11331157 r2 = f"Bayesian $R^2$ on all data = { title_info } "
11341158 percentiles = self .gradient_change .quantile ([0.03 , 1 - 0.03 ]).values
1135- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1159+ ci = (
1160+ r"$CI_{94\%}$"
1161+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1162+ )
11361163 grad_change = f"""
1137- Change in gradient = { self .gradient_change .mean ():.2f } ,
1164+ Change in gradient = { round_num ( self .gradient_change .mean (), round_to ) } ,
11381165 """
11391166 ax .set (title = r2 + "\n " + grad_change + ci )
11401167 # Intervention line
@@ -1210,9 +1237,9 @@ class PrePostNEGD(ExperimentalDesign):
12101237 Formula: post ~ 1 + C(group) + pre
12111238 <BLANKLINE>
12121239 Results:
1213- Causal impact = 1.8, $CI_{94%}$[1.6 , 2.0 ]
1240+ Causal impact = 1.8, $CI_{94%}$[1.7 , 2.1 ]
12141241 Model coefficients:
1215- Intercept -0.4, 94% HDI [-1.2 , 0.2]
1242+ Intercept -0.4, 94% HDI [-1.1 , 0.2]
12161243 C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
12171244 pre 1.0, 94% HDI [0.9, 1.1]
12181245 sigma 0.5, 94% HDI [0.4, 0.5]
@@ -1292,8 +1319,12 @@ def _input_validation(self) -> None:
12921319 """
12931320 )
12941321
1295- def plot (self ):
1296- """Plot the results"""
1322+ def plot (self , round_to = None ):
1323+ """Plot the results
1324+
1325+ :param round_to:
1326+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1327+ """
12971328 fig , ax = plt .subplots (
12981329 2 , 1 , figsize = (7 , 9 ), gridspec_kw = {"height_ratios" : [3 , 1 ]}
12991330 )
@@ -1339,18 +1370,21 @@ def plot(self):
13391370 )
13401371
13411372 # Plot estimated caual impact / treatment effect
1342- az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
1373+ az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ], round_to = round_to )
13431374 ax [1 ].set (title = "Estimated treatment effect" )
13441375 return fig , ax
13451376
1346- def _causal_impact_summary_stat (self ) -> str :
1377+ def _causal_impact_summary_stat (self , round_to ) -> str :
13471378 """Computes the mean and 94% credible interval bounds for the causal impact."""
13481379 percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
1349- ci = r"$CI_{94%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1380+ ci = (
1381+ r"$CI_{94%}$"
1382+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1383+ )
13501384 causal_impact = f"{ self .causal_impact .mean ():.2f} , "
13511385 return f"Causal impact = { causal_impact + ci } "
13521386
1353- def summary (self ) -> None :
1387+ def summary (self , round_to = None ) -> None :
13541388 """
13551389 Print text output summarising the results
13561390 """
@@ -1359,7 +1393,7 @@ def summary(self) -> None:
13591393 print (f"Formula: { self .formula } " )
13601394 print ("\n Results:" )
13611395 # TODO: extra experiment specific outputs here
1362- print (self ._causal_impact_summary_stat ())
1396+ print (self ._causal_impact_summary_stat (round_to ))
13631397 self .print_coefficients ()
13641398
13651399 def _get_treatment_effect_coeff (self ) -> str :
0 commit comments