3030
3131
3232class ExperimentalDesign :
33- """Base class"""
33+ """Base class for other experiment types """
3434
3535 model = None
3636 expt_type = None
@@ -43,7 +43,7 @@ def __init__(self, model=None, **kwargs):
4343
4444 @property
4545 def idata (self ):
46- """Access to the InferenceData object"""
46+ """Access to the models InferenceData object"""
4747 return self .model .idata
4848
4949 def print_coefficients (self ) -> None :
@@ -66,8 +66,32 @@ def print_coefficients(self) -> None:
6666
6767
6868class PrePostFit (ExperimentalDesign ):
69- """A class to analyse quasi-experiments where parameter estimation is based on just
70- the pre-intervention data."""
69+ """
70+ A class to analyse quasi-experiments where parameter estimation is based on just
71+ the pre-intervention data.
72+
73+ :param data:
74+ A pandas data frame
75+ :param treatment_time:
76+ The time when treatment occured, should be in reference to the data index
77+ :param formula:
78+ A statistical model formula
79+ :param model:
80+ A PyMC model
81+
82+ Example
83+ --------
84+ >>> sc = cp.load_data("sc")
85+ >>> seed = 42
86+ >>> result = cp.pymc_experiments.PrePostFit(
87+ ... sc,
88+ ... treatment_time,
89+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
90+ ... model=cp.pymc_models.WeightedSumFitter(
91+ ... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
92+ ... ),
93+ ... )
94+ """
7195
7296 def __init__ (
7397 self ,
@@ -256,13 +280,64 @@ def summary(self) -> None:
256280
257281
258282class InterruptedTimeSeries (PrePostFit ):
259- """Interrupted time series analysis"""
283+ """
284+ A wrapper around PrePostFit class
285+
286+ :param data:
287+ A pandas data frame
288+ :param treatment_time:
289+ The time when treatment occured, should be in reference to the data index
290+ :param formula:
291+ A statistical model formula
292+ :param model:
293+ A PyMC model
294+
295+ Example
296+ --------
297+ >>> df = (
298+ ... cp.load_data("its")
299+ ... .assign(date=lambda x: pd.to_datetime(x["date"]))
300+ ... .set_index("date")
301+ ... )
302+ >>> treatment_time = pd.to_datetime("2017-01-01")
303+ >>> seed = 42
304+ >>> result = cp.pymc_experiments.InterruptedTimeSeries(
305+ ... df,
306+ ... treatment_time,
307+ ... formula="y ~ 1 + t + C(month)",
308+ ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
309+ ... )
310+ """
260311
261312 expt_type = "Interrupted Time Series"
262313
263314
264315class SyntheticControl (PrePostFit ):
265- """A wrapper around the PrePostFit class"""
316+ """A wrapper around the PrePostFit class
317+
318+ :param data:
319+ A pandas data frame
320+ :param treatment_time:
321+ The time when treatment occured, should be in reference to the data index
322+ :param formula:
323+ A statistical model formula
324+ :param model:
325+ A PyMC model
326+
327+ Example
328+ --------
329+ >>> df = cp.load_data("sc")
330+ >>> treatment_time = 70
331+ >>> seed = 42
332+ >>> result = cp.pymc_experiments.SyntheticControl(
333+ ... df,
334+ ... treatment_time,
335+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
336+ ... model=cp.pymc_models.WeightedSumFitter(
337+ ... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
338+ ... ),
339+ ... )
340+ """
266341
267342 expt_type = "Synthetic Control"
268343
@@ -285,6 +360,28 @@ class DifferenceInDifferences(ExperimentalDesign):
285360
286361 There is no pre/post intervention data distinction for DiD, we fit all the
287362 data available.
363+ :param data:
364+ A pandas data frame
365+ :param formula:
366+ A statistical model formula
367+ :param time_variable_name:
368+ Name of the data column for the time variable
369+ :param group_variable_name:
370+ Name of the data column for the group variable
371+ :param model:
372+ A PyMC model for difference in differences
373+
374+ Example
375+ --------
376+ >>> df = cp.load_data("did")
377+ >>> seed = 42
378+ >>> result = cp.pymc_experiments.DifferenceInDifferences(
379+ ... df,
380+ ... formula="y ~ 1 + group*post_treatment",
381+ ... time_variable_name="t",
382+ ... group_variable_name="group",
383+ ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
384+ ... )
288385
289386 """
290387
@@ -572,6 +669,18 @@ class RegressionDiscontinuity(ExperimentalDesign):
572669 :param bandwidth:
573670 Data outside of the bandwidth (relative to the discontinuity) is not used to fit
574671 the model.
672+
673+ Example
674+ --------
675+ >>> df = cp.load_data("rd")
676+ >>> seed = 42
677+ >>> result = cp.pymc_experiments.RegressionDiscontinuity(
678+ ... df,
679+ ... formula="y ~ 1 + x + treated + x:treated",
680+ ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
681+ ... treatment_threshold=0.5,
682+ ... )
683+
575684 """
576685
577686 def __init__ (
@@ -742,7 +851,33 @@ def summary(self) -> None:
742851
743852
744853class PrePostNEGD (ExperimentalDesign ):
745- """A class to analyse data from pretest/posttest designs"""
854+ """
855+ A class to analyse data from pretest/posttest designs
856+
857+ :param data:
858+ A pandas data frame
859+ :param formula:
860+ A statistical model formula
861+ :param group_variable_name:
862+ Name of the column in data for the group variable
863+ :param pretreatment_variable_name:
864+ Name of the column in data for the pretreatment variable
865+ :param model:
866+ A PyMC model
867+
868+ Example
869+ --------
870+ >>> df = cp.load_data("anova1")
871+ >>> seed = 42
872+ >>> result = cp.pymc_experiments.PrePostNEGD(
873+ ... df,
874+ ... formula="post ~ 1 + C(group) + pre",
875+ ... group_variable_name="group",
876+ ... pretreatment_variable_name="pre",
877+ ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
878+ ... )
879+
880+ """
746881
747882 def __init__ (
748883 self ,
0 commit comments