@@ -1415,10 +1415,7 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14151415 s = pm .sample (10 , chains = 1 , tune = 100 )
14161416
14171417 # to test forward graph
1418- random_samples = pm .draw (
1419- v ,
1420- draws = 10 ,
1421- )
1418+ random_samples = pm .draw (v , draws = 10 )
14221419
14231420 assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
14241421
@@ -1475,14 +1472,39 @@ def test_zsn_fail_axis(self, dims, zerosum_axes):
14751472 with pm .Model (coords = COORDS ) as m :
14761473 _ = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
14771474
1478- def test_zsn_change_dist_size (self ):
1479- base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ))
1475+ @pytest .mark .parametrize (
1476+ "zerosum_axes" ,
1477+ [(- 1 ), (- 2 ), (1 ), ((0 , 1 )), ((- 2 , - 1 ))],
1478+ )
1479+ def test_zsn_change_dist_size (self , zerosum_axes ):
1480+ base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1481+ random_samples = pm .draw (base_dist , draws = 100 )
1482+
1483+ if not isinstance (zerosum_axes , (list , tuple )):
1484+ zerosum_axes = [zerosum_axes ]
1485+ self .assert_zerosum_axes (random_samples , zerosum_axes )
14801486
14811487 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
14821488 assert new_dist .eval ().shape == (5 , 3 )
1489+ random_samples = pm .draw (new_dist , draws = 100 )
1490+ self .assert_zerosum_axes (random_samples , zerosum_axes )
14831491
14841492 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = True )
14851493 assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1494+ random_samples = pm .draw (new_dist , draws = 100 )
1495+ self .assert_zerosum_axes (random_samples , zerosum_axes )
1496+
1497+ def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1498+ for ax in zerosum_axes :
1499+ if ax < 0 :
1500+ assert np .isclose (
1501+ random_samples .mean (axis = ax ), 0
1502+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1503+ else :
1504+ ax = ax + 1
1505+ assert np .isclose (
1506+ random_samples .mean (axis = ax ), 0
1507+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
14861508
14871509
14881510class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments