@@ -1432,14 +1432,14 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14321432 "error, match, shape, support_shape, zerosum_axes" ,
14331433 [
14341434 (IndexError , "index out of range" , (3 , 4 , 5 ), None , 4 ),
1435- (AssertionError , "does not match" , (3 , 4 ), 3 , None ), # support_shape should be 4
1435+ (AssertionError , "does not match" , (3 , 4 ), ( 3 ,) , None ), # support_shape should be 4
14361436 (
14371437 AssertionError ,
14381438 "does not match" ,
14391439 (3 , 4 ),
14401440 (3 , 4 ),
14411441 None ,
1442- ), # doesn't work because zerosum_axes = 1
1442+ ), # doesn't work because zerosum_axes = 1 by default
14431443 ],
14441444 )
14451445 def test_zsn_fail_axis (self , error , match , shape , support_shape , zerosum_axes ):
@@ -1449,9 +1449,20 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
14491449 "v" , shape = shape , support_shape = support_shape , zerosum_axes = zerosum_axes
14501450 )
14511451
1452- # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work
1452+ @pytest .mark .parametrize (
1453+ "shape, support_shape" ,
1454+ [
1455+ (None , (3 , 4 )),
1456+ ((3 , 4 ), (3 , 4 )),
1457+ ],
1458+ )
1459+ def test_zsn_support_shape (self , shape , support_shape ):
1460+ with pm .Model () as m :
1461+ v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , zerosum_axes = 2 )
14531462
1454- # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't
1463+ random_samples = pm .draw (v , draws = 10 )
1464+ zerosum_axes = np .arange (- 2 , 0 )
1465+ self .assert_zerosum_axes (random_samples , zerosum_axes )
14551466
14561467 @pytest .mark .parametrize (
14571468 "zerosum_axes" ,
@@ -1465,9 +1476,9 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14651476 self .assert_zerosum_axes (random_samples , zerosum_axes )
14661477
14671478 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1468- if zerosum_axes == 1 :
1479+ try :
14691480 assert new_dist .eval ().shape == (5 , 3 , 9 )
1470- elif zerosum_axes == 2 :
1481+ except AssertionError :
14711482 assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
14721483 random_samples = pm .draw (new_dist , draws = 100 )
14731484 self .assert_zerosum_axes (random_samples , zerosum_axes )
0 commit comments