|
| 1 | +""" |
| 2 | +Experiment routines for PyMC models. |
| 3 | +
|
| 4 | +Includes: |
| 5 | +1. ExperimentalDesign base class |
| 6 | +2. Pre-Post Fit |
| 7 | +3. Synthetic Control |
| 8 | +4. Difference in differences |
| 9 | +5. Regression Discontinuity |
| 10 | +""" |
1 | 11 | import warnings |
2 | 12 | from typing import Optional, Union |
3 | 13 |
|
@@ -36,7 +46,7 @@ def idata(self): |
36 | 46 | """Access to the InferenceData object""" |
37 | 47 | return self.model.idata |
38 | 48 |
|
39 | | - def print_coefficients(self): |
| 49 | + def print_coefficients(self) -> None: |
40 | 50 | """Prints the model coefficients""" |
41 | 51 | print("Model coefficients:") |
42 | 52 | coeffs = az.extract(self.idata.posterior, var_names="beta") |
@@ -236,7 +246,7 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs): |
236 | 246 |
|
237 | 247 | return (fig, ax) |
238 | 248 |
|
239 | | - def summary(self): |
| 249 | + def summary(self) -> None: |
240 | 250 | """Print text output summarising the results""" |
241 | 251 |
|
242 | 252 | print(f"{self.expt_type:=^80}") |
@@ -524,13 +534,14 @@ def _plot_causal_impact_arrow(self, ax): |
524 | 534 | va="center", |
525 | 535 | ) |
526 | 536 |
|
527 | | - def _causal_impact_summary_stat(self): |
| 537 | + def _causal_impact_summary_stat(self) -> str: |
| 538 | + """Computes the mean and 94% credible interval bounds for the causal impact.""" |
528 | 539 | percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values |
529 | 540 | ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]" |
530 | 541 | causal_impact = f"{self.causal_impact.mean():.2f}, " |
531 | 542 | return f"Causal impact = {causal_impact + ci}" |
532 | 543 |
|
533 | | - def summary(self): |
| 544 | + def summary(self) -> None: |
534 | 545 | """Print text output summarising the results""" |
535 | 546 |
|
536 | 547 | print(f"{self.expt_type:=^80}") |
@@ -716,7 +727,7 @@ def plot(self): |
716 | 727 | ) |
717 | 728 | return (fig, ax) |
718 | 729 |
|
719 | | - def summary(self): |
| 730 | + def summary(self) -> None: |
720 | 731 | """Print text output summarising the results""" |
721 | 732 |
|
722 | 733 | print(f"{self.expt_type:=^80}") |
@@ -795,7 +806,7 @@ def __init__( |
795 | 806 |
|
796 | 807 | # ================================================================ |
797 | 808 |
|
798 | | - def _input_validation(self): |
| 809 | + def _input_validation(self) -> None: |
799 | 810 | """Validate the input data and model formula for correctness""" |
800 | 811 | if not _series_has_2_levels(self.data[self.group_variable_name]): |
801 | 812 | raise DataException( |
@@ -856,13 +867,14 @@ def plot(self): |
856 | 867 | ax[1].set(title="Estimated treatment effect") |
857 | 868 | return fig, ax |
858 | 869 |
|
859 | | - def _causal_impact_summary_stat(self): |
| 870 | + def _causal_impact_summary_stat(self) -> str: |
| 871 | + """Computes the mean and 94% credible interval bounds for the causal impact.""" |
860 | 872 | percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values |
861 | 873 | ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]" |
862 | 874 | causal_impact = f"{self.causal_impact.mean():.2f}, " |
863 | 875 | return f"Causal impact = {causal_impact + ci}" |
864 | 876 |
|
865 | | - def summary(self): |
| 877 | + def summary(self) -> None: |
866 | 878 | """Print text output summarising the results""" |
867 | 879 |
|
868 | 880 | print(f"{self.expt_type:=^80}") |
|
0 commit comments