11"""
22Experiments for Scikit-Learn models
3+
4+ - ExperimentalDesign: base class for skl experiments
5+ - PrePostFit: base class for synthetic control and interrupted time series
6+ - SyntheticControl
7+ - InterruptedTimeSeries
8+ - DifferenceInDifferences
9+ - RegressionDiscontinuity
310"""
411import warnings
512from typing import Optional
@@ -27,8 +34,33 @@ def __init__(self, model=None, **kwargs):
2734
2835
2936class PrePostFit (ExperimentalDesign ):
30- """A class to analyse quasi-experiments where parameter estimation is based on just
31- the pre-intervention data."""
37+ """
38+ A class to analyse quasi-experiments where parameter estimation is based on just
39+ the pre-intervention data.
40+
41+ :param data:
42+ A pandas data frame
43+ :param treatment_time:
44+ The index or time value of when treatment begins
45+ :param formula:
46+ A statistical model formula
47+ :param model:
48+ An sklearn model object
49+
50+ Example
51+ --------
52+ >>> from sklearn.linear_model import LinearRegression
53+ >>> import causalpy as cp
54+ >>> df = cp.load_data("sc")
55+ >>> treatment_time = 70
56+ >>> result = cp.skl_experiments.PrePostFit(
57+ ... df,
58+ ... treatment_time,
59+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
60+ ... model = cp.skl_models.WeightedProportion()
61+ ... )
62+
63+ """
3264
3365 def __init__ (
3466 self ,
@@ -144,7 +176,16 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
144176 return (fig , ax )
145177
146178 def get_coeffs (self ):
147- """Returns model coefficients"""
179+ """
180+ Returns model coefficients
181+
182+ Example
183+ --------
184+ >>> result.get_coeffs()
185+ array([3.97370896e-01, 1.53881980e-01, 4.48747123e-01, 1.04639857e-16,
186+ 0.00000000e+00, 0.00000000e+00, 2.92931287e-16])
187+
188+ """
148189 return np .squeeze (self .model .coef_ )
149190
150191 def plot_coeffs (self ):
@@ -161,13 +202,68 @@ def plot_coeffs(self):
161202
162203
163204class InterruptedTimeSeries (PrePostFit ):
164- """Interrupted time series analysis"""
205+ """
206+ Interrupted time series analysis, a wrapper around the PrePostFit class
207+
208+ :param data:
209+ A pandas data frame
210+ :param treatment_time:
211+ The index or time value of when treatment begins
212+ :param formula:
213+ A statistical model formula
214+ :param model:
215+ An sklearn model object
216+
217+ Example
218+ --------
219+ >>> from sklearn.linear_model import LinearRegression
220+ >>> import pandas as pd
221+ >>> import causalpy as cp
222+ >>> df = (
223+ ... cp.load_data("its")
224+ ... .assign(date=lambda x: pd.to_datetime(x["date"]))
225+ ... .set_index("date")
226+ ... )
227+ >>> treatment_time = pd.to_datetime("2017-01-01")
228+ >>> result = cp.skl_experiments.InterruptedTimeSeries(
229+ ... df,
230+ ... treatment_time,
231+ ... formula="y ~ 1 + t + C(month)",
232+ ... model = LinearRegression()
233+ ... )
234+
235+ """
165236
166237 expt_type = "Interrupted Time Series"
167238
168239
169240class SyntheticControl (PrePostFit ):
170- """A wrapper around the PrePostFit class"""
241+ """
242+ Synthetic control analysis, a wrapper around the PrePostFit class
243+
244+ :param data:
245+ A pandas data frame
246+ :param treatment_time:
247+ The index or time value of when treatment begins
248+ :param formula:
249+ A statistical model formula
250+ :param model:
251+ An sklearn model object
252+
253+ Example
254+ --------
255+ >>> from sklearn.linear_model import LinearRegression
256+ >>> import causalpy as cp
257+ >>> df = cp.load_data("sc")
258+ >>> treatment_time = 70
259+ >>> result = cp.skl_experiments.SyntheticControl(
260+ ... df,
261+ ... treatment_time,
262+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
263+ ... model = cp.skl_models.WeightedProportion()
264+ ... )
265+
266+ """
171267
172268 def plot (self , plot_predictors = False , ** kwargs ):
173269 """Plot the results"""
@@ -187,6 +283,32 @@ class DifferenceInDifferences(ExperimentalDesign):
187283
188284 There is no pre/post intervention data distinction for DiD, we fit all the data
189285 available.
286+
287+ :param data:
288+ A pandas data frame
289+ :param formula:
290+ A statistical model formula
291+ :param time_variable_name:
292+ Name of the data column for the time variable
293+ :param group_variable_name:
294+ Name of the data column for the group variable
295+ :param model:
296+ An sklearn model object
297+
298+ Example
299+ --------
300+ >>> df = cp.load_data("did")
301+ >>> seed = 42
302+ >>> result = cp.skl_experiments.DifferenceInDifferences(
303+ ... df,
304+ ... formula="y ~ 1 + group*post_treatment",
305+ ... time_variable_name="t",
306+ ... group_variable_name="group",
307+ ... treated=1,
308+ ... untreated=0,
309+ ... model=LinearRegression(),
310+ ... )
311+
190312 """
191313
192314 def __init__ (
@@ -373,6 +495,17 @@ class RegressionDiscontinuity(ExperimentalDesign):
373495 :param bandwidth:
374496 Data outside of the bandwidth (relative to the discontinuity) is not used to fit
375497 the model.
498+
499+ Example
500+ --------
501+ >>> data = cp.load_data("rd")
502+ >>> result = cp.skl_experiments.RegressionDiscontinuity(
503+ ... data,
504+ ... formula="y ~ 1 + x + treated",
505+ ... model=LinearRegression(),
506+ ... treatment_threshold=0.5,
507+ ... )
508+
376509 """
377510
378511 def __init__ (
@@ -503,7 +636,24 @@ def plot(self):
503636 return (fig , ax )
504637
505638 def summary (self ):
506- """Print text output summarising the results"""
639+ """
640+ Print text output summarising the results
641+
642+ Example
643+ --------
644+ >>> result.summary()
645+ Difference in Differences experiment
646+ Formula: y ~ 1 + x + treated
647+ Running variable: x
648+ Threshold on running variable: 0.5
649+ Results:
650+ Discontinuity at threshold = 0.19
651+ Model coefficients:
652+ Intercept 0.0
653+ treated[T.True] 0.19034196317793994
654+ x 1.229600855360073
655+
656+ """
507657 print ("Difference in Differences experiment" )
508658 print (f"Formula: { self .formula } " )
509659 print (f"Running variable: { self .running_variable_name } " )
0 commit comments