2828from aeppl .logprob import ParameterValueError
2929from aesara .tensor import TensorVariable
3030from aesara .tensor .random .utils import broadcast_params
31+ from numpy import AxisError
3132
3233import pymc as pm
3334
@@ -754,7 +755,12 @@ def test_car_logp(self, sparse, size):
754755
755756 # d x d adjacency matrix for a square (d=4) of rook-adjacent sites
756757 W = np .array (
757- [[0.0 , 1.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 , 1.0 ], [0.0 , 1.0 , 1.0 , 0.0 ]]
758+ [
759+ [0.0 , 1.0 , 1.0 , 0.0 ],
760+ [1.0 , 0.0 , 0.0 , 1.0 ],
761+ [1.0 , 0.0 , 0.0 , 1.0 ],
762+ [0.0 , 1.0 , 1.0 , 0.0 ],
763+ ]
758764 )
759765
760766 tau = 2
@@ -1007,6 +1013,19 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10071013 # MvNormal logp is only implemented for up to 2D variables
10081014 assert_moment_is_expected (model , expected , check_finite_logp = x .ndim < 3 )
10091015
1016+ @pytest .mark .parametrize (
1017+ "shape, zerosum_axes, expected" ,
1018+ [
1019+ ((2 , 5 ), None , np .zeros ((2 , 5 ))),
1020+ ((2 , 5 , 6 ), None , np .zeros ((2 , 5 , 6 ))),
1021+ ((2 , 5 , 6 ), (0 , 1 ), np .zeros ((2 , 5 , 6 ))),
1022+ ],
1023+ )
1024+ def test_zerosum_normal_moment (self , shape , zerosum_axes , expected ):
1025+ with pm .Model () as model :
1026+ pm .ZeroSumNormal ("x" , shape = shape , zerosum_axes = zerosum_axes )
1027+ assert_moment_is_expected (model , expected )
1028+
10101029 @pytest .mark .parametrize (
10111030 "mu, size, expected" ,
10121031 [
@@ -1026,7 +1045,12 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10261045 )
10271046 def test_car_moment (self , mu , size , expected ):
10281047 W = np .array (
1029- [[0.0 , 1.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 , 1.0 ], [0.0 , 1.0 , 1.0 , 0.0 ]]
1048+ [
1049+ [0.0 , 1.0 , 1.0 , 0.0 ],
1050+ [1.0 , 0.0 , 0.0 , 1.0 ],
1051+ [1.0 , 0.0 , 0.0 , 1.0 ],
1052+ [0.0 , 1.0 , 1.0 , 0.0 ],
1053+ ]
10301054 )
10311055 tau = 2
10321056 alpha = 0.5
@@ -1367,6 +1391,100 @@ def test_issue_3706(self):
13671391 assert prior_pred ["X" ].shape == (1 , N , 2 )
13681392
13691393
1394+ COORDS = {
1395+ "regions" : ["a" , "b" , "c" ],
1396+ "answers" : ["yes" , "no" , "whatever" , "don't understand question" ],
1397+ }
1398+
1399+
1400+ class TestZeroSumNormal :
1401+ @pytest .mark .parametrize (
1402+ "dims, zerosum_axes, shape" ,
1403+ [
1404+ (("regions" , "answers" ), "answers" , None ),
1405+ (("regions" , "answers" ), ("regions" , "answers" ), None ),
1406+ (("regions" , "answers" ), 0 , None ),
1407+ (("regions" , "answers" ), - 1 , None ),
1408+ (("regions" , "answers" ), (0 , 1 ), None ),
1409+ (None , - 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1410+ ],
1411+ )
1412+ def test_zsn_dims_shape (self , dims , zerosum_axes , shape ):
1413+ with pm .Model (coords = COORDS ) as m :
1414+ v = pm .ZeroSumNormal ("v" , dims = dims , shape = shape , zerosum_axes = zerosum_axes )
1415+ s = pm .sample (10 , chains = 1 , tune = 100 )
1416+
1417+ # to test forward graph
1418+ random_samples = pm .draw (
1419+ v ,
1420+ draws = 10 ,
1421+ )
1422+
1423+ assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1424+
1425+ if not isinstance (zerosum_axes , (list , tuple )):
1426+ zerosum_axes = [zerosum_axes ]
1427+
1428+ if isinstance (zerosum_axes [0 ], str ):
1429+ for ax in zerosum_axes :
1430+ for samples in [
1431+ s .posterior .v .mean (dim = ax ),
1432+ random_samples .mean (axis = dims .index (ax ) + 1 ),
1433+ ]:
1434+ assert np .isclose (
1435+ samples , 0
1436+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1437+
1438+ nonzero_axes = list (set (dims ).difference (zerosum_axes ))
1439+ if nonzero_axes :
1440+ for ax in nonzero_axes :
1441+ for samples in [
1442+ s .posterior .v .mean (dim = ax ),
1443+ random_samples .mean (axis = dims .index (ax ) + 1 ),
1444+ ]:
1445+ assert not np .isclose (
1446+ samples , 0
1447+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1448+
1449+ else :
1450+ for ax in zerosum_axes :
1451+ if ax < 0 :
1452+ assert np .isclose (
1453+ s .posterior .v .mean (axis = ax ), 0
1454+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1455+ else :
1456+ ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1457+ assert np .isclose (
1458+ s .posterior .v .mean (axis = ax ), 0
1459+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1460+
1461+ @pytest .mark .parametrize (
1462+ "dims, zerosum_axes" ,
1463+ [
1464+ (("regions" , "answers" ), 2 ),
1465+ (("regions" , "answers" ), (0 , - 2 )),
1466+ ],
1467+ )
1468+ def test_zsn_fail_axis (self , dims , zerosum_axes ):
1469+ if isinstance (zerosum_axes , (list , tuple )):
1470+ with pytest .raises (ValueError , match = "repeated axis" ):
1471+ with pm .Model (coords = COORDS ) as m :
1472+ _ = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1473+ else :
1474+ with pytest .raises (AxisError , match = "out of bounds" ):
1475+ with pm .Model (coords = COORDS ) as m :
1476+ _ = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1477+
1478+ def test_zsn_change_dist_size (self ):
1479+ base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ))
1480+
1481+ new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1482+ assert new_dist .eval ().shape == (5 , 3 )
1483+
1484+ new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = True )
1485+ assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1486+
1487+
13701488class TestMvStudentTCov (BaseTestDistributionRandom ):
13711489 def mvstudentt_rng_fn (self , size , nu , mu , cov , rng ):
13721490 mv_samples = rng .multivariate_normal (np .zeros_like (mu ), cov , size = size )
0 commit comments