@@ -11,6 +11,7 @@ def model_c():
         m = pm.Normal("m")
         s = pm.LogNormal("s")
         pm.Normal("g", m, s, shape=5)
+        pm.Exponential("e", scale=s, shape=7)
     return mod
 
 
@@ -20,31 +21,34 @@ def model_nc():
         m = pm.Normal("m")
         s = pm.LogNormal("s")
         pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
+        pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
     return mod
 
 
-def test_reparametrize_created(model_c: pm.Model):
-    model_reparam, vip = vip_reparametrize(model_c, ["g"])
-    assert "g" in vip.get_lambda()
-    assert "g::lam_logit__" in model_reparam.named_vars
-    assert "g::tau_" in model_reparam.named_vars
+@pytest.mark.parametrize("var", ["g", "e"])
+def test_reparametrize_created(model_c: pm.Model, var):
+    model_reparam, vip = vip_reparametrize(model_c, [var])
+    assert f"{var}" in vip.get_lambda()
+    assert f"{var}::lam_logit__" in model_reparam.named_vars
+    assert f"{var}::tau_" in model_reparam.named_vars
     vip.set_all_lambda(1)
-    assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
+    assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
 
 
-def test_random_draw(model_c: pm.Model, model_nc):
+@pytest.mark.parametrize("var", ["g", "e"])
+def test_random_draw(model_c: pm.Model, model_nc, var):
     model_c = pm.do(model_c, {"m": 3, "s": 2})
     model_nc = pm.do(model_nc, {"m": 3, "s": 2})
-    model_v, vip = vip_reparametrize(model_c, ["g"])
-    assert "g" in [v.name for v in model_v.deterministics]
-    c = pm.draw(model_c["g"], random_seed=42, draws=1000)
-    nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
+    model_v, vip = vip_reparametrize(model_c, [var])
+    assert var in [v.name for v in model_v.deterministics]
+    c = pm.draw(model_c[var], random_seed=42, draws=1000)
+    nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
     vip.set_all_lambda(1)
-    v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
     vip.set_all_lambda(0)
-    v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
     vip.set_all_lambda(0.5)
-    v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
     np.testing.assert_allclose(c.mean(), nc.mean())
     np.testing.assert_allclose(c.mean(), v_0.mean())
    np.testing.assert_allclose(v_05.mean(), v_1.mean())
@@ -57,10 +61,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
 
 
 def test_reparam_fit(model_c):
-    model_v, vip = vip_reparametrize(model_c, ["g"])
+    vars = ["g", "e"]
+    model_v, vip = vip_reparametrize(model_c, ["g", "e"])
     with model_v:
-        vip.fit(random_seed=42)
-    np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
+        vip.fit(50000, random_seed=42)
+    for var in vars:
+        np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
 
 
 def test_multilevel():
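
Below is a minimal, standalone sketch (not part of the diff) of what the new `e` cases exercise: `pm.Exponential` is a scale family, so its non-centered form draws a standard `Exponential(1)` and multiplies by `s`, and `vip_reparametrize` interpolates between that and the centered form via lambda (1 = centered, 0 = non-centered). The import path for `vip_reparametrize` is an assumption; adjust it to wherever the `autoreparam` module lives in your checkout.

```python
import numpy as np
import pymc as pm

# Assumed import path; in older releases the module lived under
# pymc_experimental.model.transforms.autoreparam.
from pymc_extras.model.transforms.autoreparam import vip_reparametrize

with pm.Model() as mod:
    s = pm.LogNormal("s")
    pm.Exponential("e", scale=s, shape=7)

# Fix the scale so draws under different lambdas are directly comparable,
# mirroring the pm.do(...) calls in test_random_draw above.
mod = pm.do(mod, {"s": 2.0})
model_v, vip = vip_reparametrize(mod, ["e"])

vip.set_all_lambda(0)  # lambda = 0: fully non-centered, e = s * Exponential(1)
nc = pm.draw(model_v["e"], draws=1000, random_seed=42)
vip.set_all_lambda(1)  # lambda = 1: fully centered, e ~ Exponential(scale=s)
c = pm.draw(model_v["e"], draws=1000, random_seed=42)

print(nc.mean(), c.mean())  # both should be close to 2, the fixed scale
```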