11"""
22Experiment routines for PyMC models.
33
4- Includes:
5- 1. ExperimentalDesign base class
6- 2. Pre-Post Fit
7- 3. Synthetic Control
8- 4. Difference in differences
9- 5. Regression Discontinuity
4+ - ExperimentalDesign base class
5+ - Pre-Post Fit
6+ - Interrupted Time Series
7+ - Synthetic Control
8+ - Difference in differences
9+ - Regression Discontinuity
10+ - Pretest/Posttest Nonequivalent Group Design
11+
1012"""
13+
1114import warnings
1215from typing import Optional , Union
1316
3033
3134
3235class ExperimentalDesign :
33- """Base class for other experiment types"""
36+ """
37+ Base class for other experiment types
38+
39+ See subclasses for examples of most methods
40+ """
3441
3542 model = None
3643 expt_type = None
@@ -43,11 +50,63 @@ def __init__(self, model=None, **kwargs):
4350
4451 @property
4552 def idata (self ):
46- """Access to the models InferenceData object"""
53+ """
54+ Access to the models InferenceData object
55+
56+ Example
57+ --------
58+ If `result` is the result of the Difference in Differences experiment example
59+
60+ >>> result.idata
61+ Inference data with groups:
62+ > posterior
63+ > posterior_predictive
64+ > sample_stats
65+ > prior
66+ > prior_predictive
67+ > observed_data
68+ > constant_data
69+ >>> result.idata.posterior
70+ <xarray.Dataset>
71+ Dimensions: (chain: 4, draw: 1000, coeffs: 4, obs_ind: 40)
72+ Coordinates:
73+ * chain (chain) int64 0 1 2 3
74+ * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998
75+ 999
76+ * coeffs (coeffs) <U28 'Intercept' ... 'group:post_treatment[T.True]'
77+ * obs_ind (obs_ind) int64 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37
78+ 38 39
79+ Data variables:
80+ beta (chain, draw, coeffs) float64 1.04 1.013 0.173 ... 0.1873 0.5225
81+ sigma (chain, draw) float64 0.09331 0.1031 0.1024 ... 0.0824 0.06907
82+ mu (chain, draw, obs_ind) float64 1.04 2.053 1.213 ... 1.265 2.747
83+ Attributes:
84+ created_at: 2023-08-23T20:03:45.709265
85+ arviz_version: 0.16.1
86+ inference_library: pymc
87+ inference_library_version: 5.7.2
88+ sampling_time: 0.8851289749145508
89+ tuning_steps: 1000
90+ """
91+
4792 return self .model .idata
4893
4994 def print_coefficients (self ) -> None :
50- """Prints the model coefficients"""
95+ """
96+ Prints the model coefficients
97+
98+ Example
99+ --------
100+ If `result` is from the Difference in Differences experiment example
101+
102+ >>> result.print_coefficients()
103+ Model coefficients:
104+ Intercept 1.08, 94% HDI [1.03, 1.13]
105+ post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
106+ group 0.16, 94% HDI [0.09, 0.23]
107+ group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
108+ sigma 0.08, 94% HDI [0.07, 0.10]
109+ """
51110 print ("Model coefficients:" )
52111 coeffs = az .extract (self .idata .posterior , var_names = "beta" )
53112 # Note: f"{name: <30}" pads the name with spaces so that we have alignment of
@@ -82,6 +141,7 @@ class PrePostFit(ExperimentalDesign):
82141 Example
83142 --------
84143 >>> sc = cp.load_data("sc")
144+ >>> treatment_time = 70
85145 >>> seed = 42
86146 >>> result = cp.pymc_experiments.PrePostFit(
87147 ... sc,
@@ -91,6 +151,17 @@ class PrePostFit(ExperimentalDesign):
91151 ... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
92152 ... ),
93153 ... )
154+ Auto-assigning NUTS sampler...
155+ Initializing NUTS using jitter+adapt_diag...
156+ Multiprocess sampling (4 chains in 4 jobs)
157+ NUTS: [beta, sigma]
158+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
159+ (4_000 + 4_000 draws total) took 11 seconds.
160+ Sampling: [beta, sigma, y_hat]
161+ Sampling: [y_hat]
162+ Sampling: [y_hat]
163+ Sampling: [y_hat]
164+ Sampling: [y_hat]
94165 """
95166
96167 def __init__ (
@@ -105,6 +176,8 @@ def __init__(
105176 self ._input_validation (data , treatment_time )
106177
107178 self .treatment_time = treatment_time
179+ # set experiment type - usually done in subclasses
180+ self .expt_type = "Pre-Post Fit"
108181 # split data in to pre and post intervention
109182 self .datapre = data [data .index <= self .treatment_time ]
110183 self .datapost = data [data .index > self .treatment_time ]
@@ -171,7 +244,14 @@ def _input_validation(self, data, treatment_time):
171244 )
172245
173246 def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
174- """Plot the results"""
247+ """
248+ Plot the results
249+
250+ Example
251+ --------
252+ >>> result.plot()
253+
254+ """
175255 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
176256
177257 # TOP PLOT --------------------------------------------------
@@ -271,7 +351,24 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
271351 return (fig , ax )
272352
273353 def summary (self ) -> None :
274- """Print text output summarising the results"""
354+ """
355+ Print text output summarising the results
356+
357+ Example
358+ ---------
359+ >>> result.summary()
360+ ===============================Synthetic Control===============================
361+ Formula: actual ~ 0 + a + b + c + d + e + f + g
362+ Model coefficients:
363+ a 0.33, 94% HDI [0.30, 0.38]
364+ b 0.05, 94% HDI [0.01, 0.09]
365+ c 0.31, 94% HDI [0.26, 0.35]
366+ d 0.06, 94% HDI [0.01, 0.10]
367+ e 0.02, 94% HDI [0.00, 0.06]
368+ f 0.20, 94% HDI [0.12, 0.26]
369+ g 0.04, 94% HDI [0.00, 0.08]
370+ sigma 0.26, 94% HDI [0.22, 0.30]
371+ """
275372
276373 print (f"{ self .expt_type :=^80} " )
277374 print (f"Formula: { self .formula } " )
@@ -307,6 +404,17 @@ class InterruptedTimeSeries(PrePostFit):
307404 ... formula="y ~ 1 + t + C(month)",
308405 ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
309406 ... )
407+ Auto-assigning NUTS sampler...
408+ Initializing NUTS using jitter+adapt_diag...
409+ Multiprocess sampling (4 chains in 4 jobs)
410+ NUTS: [beta, sigma]
411+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
412+ (4_000 + 4_000 draws total) took 3 seconds.
413+ Sampling: [beta, sigma, y_hat]
414+ Sampling: [y_hat]
415+ Sampling: [y_hat]
416+ Sampling: [y_hat]
417+ Sampling: [y_hat]
310418 """
311419
312420 expt_type = "Interrupted Time Series"
@@ -337,6 +445,17 @@ class SyntheticControl(PrePostFit):
337445 ... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
338446 ... ),
339447 ... )
448+ Auto-assigning NUTS sampler...
449+ Initializing NUTS using jitter+adapt_diag...
450+ Multiprocess sampling (4 chains in 4 jobs)
451+ NUTS: [beta, sigma]
452+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
453+ (4_000 + 4_000 draws total) took 11 seconds.
454+ Sampling: [beta, sigma, y_hat]
455+ Sampling: [y_hat]
456+ Sampling: [y_hat]
457+ Sampling: [y_hat]
458+ Sampling: [y_hat]
340459 """
341460
342461 expt_type = "Synthetic Control"
@@ -382,7 +501,17 @@ class DifferenceInDifferences(ExperimentalDesign):
382501 ... group_variable_name="group",
383502 ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
384503 ... )
385-
504+ Auto-assigning NUTS sampler...
505+ Initializing NUTS using jitter+adapt_diag...
506+ Multiprocess sampling (4 chains in 4 jobs)
507+ NUTS: [beta, sigma]
508+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
509+ (4_000 + 4_000 draws total) took 1 seconds.
510+ Sampling: [beta, sigma, y_hat]
511+ Sampling: [y_hat]
512+ Sampling: [y_hat]
513+ Sampling: [y_hat]
514+ Sampling: [y_hat]
386515 """
387516
388517 def __init__ (
@@ -503,6 +632,12 @@ def _input_validation(self):
503632 def plot (self ):
504633 """Plot the results.
505634 Creating the combined mean + HDI legend entries is a bit involved.
635+
636+ Example
637+ --------
638+ Assuming `result` is the result of a DiD experiment:
639+
640+ >>> result.plot()
506641 """
507642 fig , ax = plt .subplots ()
508643
@@ -639,7 +774,25 @@ def _causal_impact_summary_stat(self) -> str:
639774 return f"Causal impact = { causal_impact + ci } "
640775
641776 def summary (self ) -> None :
642- """Print text output summarising the results"""
777+ """
778+ Print text output summarising the results
779+
780+ Example
781+ --------
782+ Assuming `result` is a DiD experiment
783+
784+ >>> result.summary()
785+ ==========================Difference in Differences=========================
786+ Formula: y ~ 1 + group*post_treatment
787+ Results:
788+ Causal impact = 0.51, $CI_{94%}$[0.41, 0.61]
789+ Model coefficients:
790+ Intercept 1.08, 94% HDI [1.03, 1.13]
791+ post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
792+ group 0.16, 94% HDI [0.09, 0.23]
793+ group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
794+ sigma 0.08, 94% HDI [0.07, 0.10]
795+ """
643796
644797 print (f"{ self .expt_type :=^80} " )
645798 print (f"Formula: { self .formula } " )
@@ -680,7 +833,17 @@ class RegressionDiscontinuity(ExperimentalDesign):
680833 ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
681834 ... treatment_threshold=0.5,
682835 ... )
683-
836+ Auto-assigning NUTS sampler...
837+ Initializing NUTS using jitter+adapt_diag...
838+ Multiprocess sampling (4 chains in 4 jobs)
839+ NUTS: [beta, sigma]
840+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
841+ (4_000 + 4_000 draws total) took 2 seconds.
842+ Sampling: [beta, sigma, y_hat]
843+ Sampling: [y_hat]
844+ Sampling: [y_hat]
845+ Sampling: [y_hat]
846+ Sampling: [y_hat]
684847 """
685848
686849 def __init__ (
@@ -791,7 +954,13 @@ def _is_treated(self, x):
791954 return np .greater_equal (x , self .treatment_threshold )
792955
793956 def plot (self ):
794- """Plot the results"""
957+ """
958+ Plot the results
959+
960+ Example
961+ --------
962+ >>> result.plot()
963+ """
795964 fig , ax = plt .subplots ()
796965 # Plot raw data
797966 sns .scatterplot (
@@ -837,7 +1006,25 @@ def plot(self):
8371006 return (fig , ax )
8381007
8391008 def summary (self ) -> None :
840- """Print text output summarising the results"""
1009+ """
1010+ Print text output summarising the results
1011+
1012+ Example
1013+ --------
1014+ >>> result.summary()
1015+ ============================Regression Discontinuity==========================
1016+ Formula: y ~ 1 + x + treated + x:treated
1017+ Running variable: x
1018+ Threshold on running variable: 0.5
1019+ Results:
1020+ Discontinuity at threshold = 0.92
1021+ Model coefficients:
1022+ Intercept 0.09, 94% HDI [0.00, 0.17]
1023+ treated[T.True] 2.48, 94% HDI [1.66, 3.27]
1024+ x 1.32, 94% HDI [1.14, 1.50]
1025+ x:treated[T.True] -3.12, 94% HDI [-4.17, -2.05]
1026+ sigma 0.35, 94% HDI [0.31, 0.41]
1027+ """
8411028
8421029 print (f"{ self .expt_type :=^80} " )
8431030 print (f"Formula: { self .formula } " )
@@ -876,7 +1063,16 @@ class PrePostNEGD(ExperimentalDesign):
8761063 ... pretreatment_variable_name="pre",
8771064 ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
8781065 ... )
879-
1066+ Auto-assigning NUTS sampler...
1067+ Initializing NUTS using jitter+adapt_diag...
1068+ Multiprocess sampling (4 chains in 4 jobs)
1069+ NUTS: [beta, sigma]
1070+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
1071+ (4_000 + 4_000 draws total) took 3 seconds.
1072+ Sampling: [beta, sigma, y_hat]
1073+ Sampling: [y_hat]
1074+ Sampling: [y_hat]
1075+ Sampling: [y_hat]
8801076 """
8811077
8821078 def __init__ (
@@ -1010,7 +1206,23 @@ def _causal_impact_summary_stat(self) -> str:
10101206 return f"Causal impact = { causal_impact + ci } "
10111207
10121208 def summary (self ) -> None :
1013- """Print text output summarising the results"""
1209+ """
1210+ Print text output summarising the results
1211+
1212+ Example
1213+ --------
1214+ >>> result.summary()
1215+ =================Pretest/posttest Nonequivalent Group Design================
1216+ Formula: post ~ 1 + C(group) + pre
1217+ Results:
1218+ Causal impact = 1.89, $CI_{94%}$[1.70, 2.07]
1219+ Model coefficients:
1220+ Intercept -0.46, 94% HDI [-1.17, 0.22]
1221+ C(group)[T.1] 1.89, 94% HDI [1.70, 2.07]
1222+ pre 1.05, 94% HDI [0.98, 1.12]
1223+ sigma 0.51, 94% HDI [0.46, 0.56]
1224+
1225+ """
10141226
10151227 print (f"{ self .expt_type :=^80} " )
10161228 print (f"Formula: { self .formula } " )
0 commit comments