@@ -1388,6 +1388,18 @@ def test_issue_3706(self):
13881388
13891389
13901390class TestZeroSumNormal :
1391+ def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1392+ if check_zerosum_axes :
1393+ for ax in axes_to_check :
1394+ assert np .isclose (
1395+ random_samples .mean (axis = ax ), 0
1396+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1397+ else :
1398+ for ax in axes_to_check :
1399+ assert not np .isclose (
1400+ random_samples .mean (axis = ax ), 0
1401+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1402+
13911403 @pytest .mark .parametrize (
13921404 "dims, zerosum_axes" ,
13931405 [
@@ -1504,18 +1516,6 @@ def test_zsn_change_dist_size(self, zerosum_axes):
15041516 random_samples = pm .draw (new_dist , draws = 100 )
15051517 self .assert_zerosum_axes (random_samples , zerosum_axes )
15061518
1507- def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1508- if check_zerosum_axes :
1509- for ax in axes_to_check :
1510- assert np .isclose (
1511- random_samples .mean (axis = ax ), 0
1512- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1513- else :
1514- for ax in axes_to_check :
1515- assert not np .isclose (
1516- random_samples .mean (axis = ax ), 0
1517- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1518-
15191519 @pytest .mark .parametrize (
15201520 "sigma, n" ,
15211521 [
0 commit comments