@@ -1543,41 +1543,39 @@ def test_zsn_variance(self, sigma, n):
15431543 ],
15441544 )
15451545 def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1546+ def logp_norm (value , sigma , axes ):
1547+ """
1548+ Special case of the MvNormal, that's equivalent to the ZSN.
1549+ Only to test the ZSN logp
1550+ """
1551+ axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1552+ if len (set (axes )) < len (axes ):
1553+ raise ValueError ("Must specify unique zero sum axes" )
1554+ other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1555+ new_order = other_axes + axes
1556+ reshaped_value = np .reshape (
1557+ np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1558+ )
15461559
1547- zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1548- zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1549- mvn_logp = self .logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1560+ degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1561+ full_size = np .prod ([value .shape [ax ] for ax in axes ])
15501562
1551- np .testing .assert_allclose (zsn_logp , mvn_logp )
1563+ psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1564+ exp = 0.5 * (reshaped_value / sigma ) ** 2
1565+ inds = np .ones_like (value , dtype = "bool" )
1566+ for ax in axes :
1567+ inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1568+ inds = np .reshape (
1569+ np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1570+ )[..., 0 ]
15521571
1553- def logp_norm (self , value , sigma , axes ):
1554- """
1555- Special case of the MvNormal, that's equivalent to the ZSN.
1556- Only to test the ZSN logp
1557- """
1558- axes = [ax if ax >= 0 else value .ndim + ax for ax in axes ]
1559- if len (set (axes )) < len (axes ):
1560- raise ValueError ("Must specify unique zero sum axes" )
1561- other_axes = [ax for ax in range (value .ndim ) if ax not in axes ]
1562- new_order = other_axes + axes
1563- reshaped_value = np .reshape (
1564- np .transpose (value , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1565- )
1566-
1567- degrees_of_freedom = np .prod ([value .shape [ax ] - 1 for ax in axes ])
1568- full_size = np .prod ([value .shape [ax ] for ax in axes ])
1572+ return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
15691573
1570- ns = value .shape [- 1 ]
1571- psdet = (0.5 * np .log (2 * np .pi ) + np .log (sigma )) * degrees_of_freedom / full_size
1572- exp = 0.5 * (reshaped_value / sigma ) ** 2
1573- inds = np .ones_like (value , dtype = "bool" )
1574- for ax in axes :
1575- inds = np .logical_and (inds , np .abs (np .mean (value , axis = ax , keepdims = True )) < 1e-9 )
1576- inds = np .reshape (
1577- np .transpose (inds , new_order ), [value .shape [ax ] for ax in other_axes ] + [- 1 ]
1578- )[..., 0 ]
1574+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1575+ zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1576+ mvn_logp = logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
15791577
1580- return np .where ( inds , np . sum ( - psdet - exp , axis = - 1 ), - np . inf )
1578+ np .testing . assert_allclose ( zsn_logp , mvn_logp )
15811579
15821580
15831581class TestMvStudentTCov (BaseTestDistributionRandom ):
0 commit comments