@@ -1467,18 +1467,19 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
14671467
14681468 @pytest .mark .parametrize (
14691469 "zerosum_axes" ,
1470- [( - 1 ), ( - 2 ), ( 1 ), (( 0 , 1 )), (( - 2 , - 1 )) ],
1470+ [1 , 2 ],
14711471 )
14721472 def test_zsn_change_dist_size (self , zerosum_axes ):
14731473 base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
14741474 random_samples = pm .draw (base_dist , draws = 100 )
14751475
1476- if not isinstance (zerosum_axes , (list , tuple )):
1477- zerosum_axes = [zerosum_axes ]
14781476 self .assert_zerosum_axes (random_samples , zerosum_axes )
14791477
14801478 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1481- assert new_dist .eval ().shape == (5 , 3 )
1479+ if zerosum_axes == 1 :
1480+ assert new_dist .eval ().shape == (5 , 3 , 9 )
1481+ elif zerosum_axes == 2 :
1482+ assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
14821483 random_samples = pm .draw (new_dist , draws = 100 )
14831484 self .assert_zerosum_axes (random_samples , zerosum_axes )
14841485
@@ -1488,16 +1489,11 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14881489 self .assert_zerosum_axes (random_samples , zerosum_axes )
14891490
14901491 def assert_zerosum_axes (self , random_samples , zerosum_axes ):
1492+ zerosum_axes = np .arange (- zerosum_axes , 0 )
14911493 for ax in zerosum_axes :
1492- if ax < 0 :
1493- assert np .isclose (
1494- random_samples .mean (axis = ax ), 0
1495- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1496- else :
1497- ax = ax + 1
1498- assert np .isclose (
1499- random_samples .mean (axis = ax ), 0
1500- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1494+ assert np .isclose (
1495+ random_samples .mean (axis = ax ), 0
1496+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
15011497
15021498
15031499class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments