@@ -1418,27 +1418,15 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14181418
14191419 assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
14201420
1421- zerosum_axes = np .arange (- v .owner .op .ndim_supp , 0 )
1422- nonzero_axes = np .arange (v .ndim - v .owner .op .ndim_supp )
1423-
1424- for ax in zerosum_axes :
1425- for samples in [
1426- s .posterior .v .mean (axis = ax ),
1427- random_samples .mean (axis = ax ),
1428- ]:
1429- assert np .isclose (
1430- samples , 0
1431- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1432-
1433- if nonzero_axes :
1434- for ax in nonzero_axes :
1435- for samples in [
1436- s .posterior .v .mean (axis = ax ),
1437- random_samples .mean (axis = ax ),
1438- ]:
1439- assert not np .isclose (
1440- samples , 0
1441- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1421+ ndim_supp = v .owner .op .ndim_supp
1422+ zerosum_axes = np .arange (- ndim_supp , 0 )
1423+ nonzero_axes = np .arange (v .ndim - ndim_supp )
1424+ for samples in [
1425+ s .posterior .v ,
1426+ random_samples ,
1427+ ]:
1428+ self .assert_zerosum_axes (samples , zerosum_axes )
1429+ self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
14421430
14431431 @pytest .mark .parametrize (
14441432 "error, match, shape, support_shape, zerosum_axes" ,
@@ -1473,6 +1461,7 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14731461 base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
14741462 random_samples = pm .draw (base_dist , draws = 100 )
14751463
1464+ zerosum_axes = np .arange (- zerosum_axes , 0 )
14761465 self .assert_zerosum_axes (random_samples , zerosum_axes )
14771466
14781467 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
@@ -1488,12 +1477,17 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14881477 random_samples = pm .draw (new_dist , draws = 100 )
14891478 self .assert_zerosum_axes (random_samples , zerosum_axes )
14901479
1491- def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1492- zerosum_axes = np .arange (- zerosum_axes , 0 )
1493- for ax in zerosum_axes :
1494- assert np .isclose (
1495- random_samples .mean (axis = ax ), 0
1496- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1480+ def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1481+ if check_zerosum_axes :
1482+ for ax in axes_to_check :
1483+ assert np .isclose (
1484+ random_samples .mean (axis = ax ), 0
1485+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1486+ else :
1487+ for ax in axes_to_check :
1488+ assert not np .isclose (
1489+ random_samples .mean (axis = ax ), 0
1490+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14971491
14981492
14991493class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments