@@ -258,3 +258,28 @@ def test_geolift1():
258258 assert isinstance (result , cp .pymc_experiments .SyntheticControl )
259259 assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
260260 assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
261+
262+
263+ @pytest .mark .integration
264+ def test_iv_reg ():
265+ df = cp .load_data ("risk" )
266+ instruments_formula = "risk ~ 1 + logmort0"
267+ formula = "loggdp ~ 1 + risk"
268+ instruments_data = df [["risk" , "logmort0" ]]
269+ data = df [["loggdp" , "risk" ]]
270+
271+ result = cp .pymc_experiments .InstrumentalVariable (
272+ instruments_data = instruments_data ,
273+ data = data ,
274+ instruments_formula = instruments_formula ,
275+ formula = formula ,
276+ model = cp .pymc_models .InstrumentalVariableRegression (
277+ sample_kwargs = sample_kwargs
278+ ),
279+ )
280+ assert isinstance (df , pd .DataFrame )
281+ assert isinstance (data , pd .DataFrame )
282+ assert isinstance (instruments_data , pd .DataFrame )
283+ assert isinstance (result , cp .pymc_experiments .InstrumentalVariable )
284+ assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
285+ assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
0 commit comments