@@ -35,6 +35,17 @@ def test_did_validation_post_treatment_formula():
3535 model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
3636 )
3737
38+ with pytest .raises (FormulaException ):
39+ _ = cp .skl_experiments .DifferenceInDifferences (
40+ df ,
41+ formula = "y ~ 1 + group*post_SOMETHING" ,
42+ time_variable_name = "t" ,
43+ group_variable_name = "group" ,
44+ treated = 1 ,
45+ untreated = 0 ,
46+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
47+ )
48+
3849
3950def test_did_validation_post_treatment_data ():
4051 """Test that we get a DataException if do not include post_treatment in the data"""
@@ -57,6 +68,17 @@ def test_did_validation_post_treatment_data():
5768 model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
5869 )
5970
71+ with pytest .raises (DataException ):
72+ _ = cp .skl_experiments .DifferenceInDifferences (
73+ df ,
74+ formula = "y ~ 1 + group*post_treatment" ,
75+ time_variable_name = "t" ,
76+ group_variable_name = "group" ,
77+ treated = 1 ,
78+ untreated = 0 ,
79+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
80+ )
81+
6082
6183def test_did_validation_unit_data ():
6284 """Test that we get a DataException if do not include unit in the data"""
@@ -79,6 +101,17 @@ def test_did_validation_unit_data():
79101 model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
80102 )
81103
104+ with pytest .raises (DataException ):
105+ _ = cp .skl_experiments .DifferenceInDifferences (
106+ df ,
107+ formula = "y ~ 1 + group*post_treatment" ,
108+ time_variable_name = "t" ,
109+ group_variable_name = "group" ,
110+ treated = 1 ,
111+ untreated = 0 ,
112+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
113+ )
114+
82115
83116def test_did_validation_group_dummy_coded ():
84117 """Test that we get a DataException if the group variable is not dummy coded"""
@@ -101,6 +134,17 @@ def test_did_validation_group_dummy_coded():
101134 model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
102135 )
103136
137+ with pytest .raises (DataException ):
138+ _ = cp .skl_experiments .DifferenceInDifferences (
139+ df ,
140+ formula = "y ~ 1 + group*post_treatment" ,
141+ time_variable_name = "t" ,
142+ group_variable_name = "group" ,
143+ treated = 1 ,
144+ untreated = 0 ,
145+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
146+ )
147+
104148
105149# Synthetic Control
106150
@@ -118,6 +162,16 @@ def test_sc_input_error():
118162 model = cp .pymc_models .WeightedSumFitter (sample_kwargs = sample_kwargs ),
119163 )
120164
165+ with pytest .raises (BadIndexException ):
166+ df = cp .load_data ("sc" )
167+ treatment_time = pd .to_datetime ("2016 June 24" )
168+ _ = cp .skl_experiments .SyntheticControl (
169+ df ,
170+ treatment_time ,
171+ formula = "actual ~ 0 + a + b + c + d + e + f + g" ,
172+ model = cp .skl_models .WeightedProportion (),
173+ )
174+
121175
122176def test_sc_brexit_input_error ():
123177 """Confirm a BadIndexException is raised if the data index is datetime and the
@@ -187,6 +241,16 @@ def test_rd_validation_treated_in_formula():
187241 treatment_threshold = 0.5 ,
188242 )
189243
244+ with pytest .raises (FormulaException ):
245+ from sklearn .linear_model import LinearRegression
246+
247+ _ = cp .skl_experiments .RegressionDiscontinuity (
248+ df ,
249+ formula = "y ~ 1 + x" ,
250+ model = LinearRegression (),
251+ treatment_threshold = 0.5 ,
252+ )
253+
190254
191255def test_rd_validation_treated_is_dummy ():
192256 """Test that we get a DataException if treated is not dummy coded"""
@@ -206,6 +270,16 @@ def test_rd_validation_treated_is_dummy():
206270 treatment_threshold = 0.5 ,
207271 )
208272
273+ from sklearn .linear_model import LinearRegression
274+
275+ with pytest .raises (DataException ):
276+ _ = cp .skl_experiments .RegressionDiscontinuity (
277+ df ,
278+ formula = "y ~ 1 + x + treated" ,
279+ model = LinearRegression (),
280+ treatment_threshold = 0.5 ,
281+ )
282+
209283
210284def test_iv_treatment_var_is_present ():
211285 """Test the treatment variable is present for Instrumental Variable experiment"""
0 commit comments