@@ -1543,6 +1543,52 @@ def test_zsn_variance(self, sigma, n):
15431543
15441544 np .testing .assert_allclose (empirical_var , theoretical_var , rtol = 1e-02 )
15451545
1546+ @pytest .mark .parametrize (
1547+ "sigma, shape, zerosum_axes, mvn_axes" ,
1548+ [
1549+ (5 , 3 , None , [- 1 ]),
1550+ (2 , 6 , None , [- 1 ]),
1551+ (5 , (7 , 3 ), None , [- 1 ]),
1552+ (5 , (2 , 7 , 3 ), 2 , [1 , 2 ]),
1553+ ],
1554+ )
1555+ def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1556+
1557+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1558+ zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1559+ mvn_logp = self .logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1560+
1561+ np .testing .assert_allclose (zsn_logp , mvn_logp )
1562+
1563+ def logp_norm (self , value , sigma , axes ):
1564+ """
1565+ Special case of the MvNormal, that's equivalent to the ZSN.
1566+ Only to test the ZSN logp
1567+ """
1568+ axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1569+ if len (set (axes )) < len (axes ):
1570+ raise ValueError ("Must specify unique zero sum axes" )
1571+ other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1572+ new_order = other_axes + axes
1573+ reshaped_value = np .reshape (
1574+ np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1575+ )
1576+
1577+ degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1578+ full_size = np .prod ([value .shape [ax ] for ax in axes ])
1579+
1580+ ns = value .shape [- 1 ]
1581+ psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1582+ exp = 0.5 * (reshaped_value / sigma ) ** 2
1583+ inds = np .ones_like (value , dtype = "bool" )
1584+ for ax in axes :
1585+ inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1586+ inds = np .reshape (
1587+ np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1588+ )[..., 0 ]
1589+
1590+ return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
1591+
15461592
15471593class TestMvStudentTCov (BaseTestDistributionRandom ):
15481594 def mvstudentt_rng_fn (self , size , nu , mu , cov , rng ):
0 commit comments