@@ -1381,13 +1381,12 @@ def test_issue_3706(self):
13811381 assert prior_pred ["X" ].shape == (1 , N , 2 )
13821382
13831383
1384- COORDS = {
1385- "regions" : ["a" , "b" , "c" ],
1386- "answers" : ["yes" , "no" , "whatever" , "don't understand question" ],
1387- }
1388-
1389-
13901384class TestZeroSumNormal :
1385+ coords = {
1386+ "regions" : ["a" , "b" , "c" ],
1387+ "answers" : ["yes" , "no" , "whatever" , "don't understand question" ],
1388+ }
1389+
13911390 def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
13921391 if check_zerosum_axes :
13931392 for ax in axes_to_check :
@@ -1409,14 +1408,19 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
14091408 ],
14101409 )
14111410 def test_zsn_dims (self , dims , zerosum_axes ):
1412- with pm .Model (coords = COORDS ) as m :
1411+ with pm .Model (coords = self . coords ) as m :
14131412 v = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
14141413 s = pm .sample (10 , chains = 1 , tune = 100 )
14151414
14161415 # to test forward graph
14171416 random_samples = pm .draw (v , draws = 10 )
14181417
1419- assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1418+ assert s .posterior .v .shape == (
1419+ 1 ,
1420+ 10 ,
1421+ len (self .coords ["regions" ]),
1422+ len (self .coords ["answers" ]),
1423+ )
14201424
14211425 ndim_supp = v .owner .op .ndim_supp
14221426 zerosum_axes = np .arange (- ndim_supp , 0 )
@@ -1429,22 +1433,25 @@ def test_zsn_dims(self, dims, zerosum_axes):
14291433 self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
14301434
14311435 @pytest .mark .parametrize (
1432- "zerosum_axes, shape" ,
1433- [
1434- (None , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1435- (1 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1436- (2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1437- ],
1436+ "zerosum_axes" ,
1437+ (None , 1 , 2 ),
14381438 )
1439- def test_zsn_shape (self , shape , zerosum_axes ):
1440- with pm .Model (coords = COORDS ) as m :
1439+ def test_zsn_shape (self , zerosum_axes ):
1440+ shape = (len (self .coords ["regions" ]), len (self .coords ["answers" ]))
1441+
1442+ with pm .Model (coords = self .coords ) as m :
14411443 v = pm .ZeroSumNormal ("v" , shape = shape , zerosum_axes = zerosum_axes )
14421444 s = pm .sample (10 , chains = 1 , tune = 100 )
14431445
14441446 # to test forward graph
14451447 random_samples = pm .draw (v , draws = 10 )
14461448
1447- assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1449+ assert s .posterior .v .shape == (
1450+ 1 ,
1451+ 10 ,
1452+ len (self .coords ["regions" ]),
1453+ len (self .coords ["answers" ]),
1454+ )
14481455
14491456 ndim_supp = v .owner .op .ndim_supp
14501457 zerosum_axes = np .arange (- ndim_supp , 0 )
@@ -1525,13 +1532,13 @@ def test_zsn_change_dist_size(self, zerosum_axes):
15251532 )
15261533 def test_zsn_variance (self , sigma , n ):
15271534
1528- dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = n )
1529- random_samples = pm .draw (dist , draws = 100_000 )
1535+ dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = ( 100_000 , n ) )
1536+ random_samples = pm .draw (dist )
15301537
15311538 empirical_var = random_samples .var (axis = 0 )
15321539 theoretical_var = sigma ** 2 * (n - 1 ) / n
15331540
1534- np .testing .assert_allclose (empirical_var , theoretical_var , rtol = 1e-02 )
1541+ np .testing .assert_allclose (empirical_var , theoretical_var , atol = 0.4 )
15351542
15361543 @pytest .mark .parametrize (
15371544 "sigma, shape, zerosum_axes, mvn_axes" ,
0 commit comments