1818import warnings
1919
2020from functools import reduce
21+ from typing import Optional
2122
2223import aesara
2324import aesara .tensor as at
3637from aesara .tensor .random .utils import broadcast_params
3738from aesara .tensor .slinalg import Cholesky , SolveTriangular
3839from aesara .tensor .type import TensorType
39-
40- # from numpy.core.numeric import normalize_axis_tuple
4140from scipy import linalg , stats
4241
4342import pymc as pm
@@ -2412,20 +2411,24 @@ class ZeroSumNormal(Distribution):
24122411 It's actually the standard deviation of the underlying, unconstrained Normal distribution.
24132412 Defaults to 1 if not specified.
24142413 For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2415- zerosum_axes: list or tuple of strings or integers
2416- Axis (or axes) along which the zero-sum constraint is enforced.
2417- Defaults to [-1], i.e the last axis.
2418- If strings are passed, then ``dims`` is needed.
2419- Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
2420- dims: list or tuple of strings, optional
2421- The dimension names of the axes.
2422- Necessary when ``zerosum_axes`` is specified with strings.
2414+ zerosum_axes: int, defaults to 1
2415+ Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2416+ Defaults to 1, i.e the rightmost axis.
2417+ dims: sequence of strings, optional
2418+ Dimension names of the distribution. Works the same as for other PyMC distributions.
2419+ Necessary if ``shape`` is not passed.
2420+ shape: tuple of integers, optional
2421+ Shape of the distribution. Works the same as for other PyMC distributions.
2422+ Necessary if ``dims`` is not passed.
24232423
24242424 Warnings
24252425 --------
24262426 ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
24272427 The ability to specifiy a vector of ``sigma`` may be added in future versions.
24282428
2429+ ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
2430+ just use ``pm.Normal``.
2431+
24292432 Examples
24302433 --------
24312434 .. code-block:: python
@@ -2444,23 +2447,21 @@ class ZeroSumNormal(Distribution):
24442447 """
24452448 rv_type = ZeroSumNormalRV
24462449
2447- # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
2448- # dims = convert_dims(dims)
2449- # if zerosum_axes is None:
2450- # zerosum_axes = [-1]
2451- # if not isinstance(zerosum_axes, (list, tuple)):
2452- # zerosum_axes = [zerosum_axes]
2450+ def __new__ (cls , * args , zerosum_axes = None , support_shape = None , dims = None , ** kwargs ):
2451+ if dims is not None or kwargs .get ("observed" ) is not None :
2452+ zerosum_axes = cls .check_zerosum_axes (zerosum_axes )
24532453
2454- # if isinstance(zerosum_axes[0], str):
2455- # if not dims:
2456- # raise ValueError("You need to specify dims if zerosum_axes are strings.")
2457- # else:
2458- # zerosum_axes_ = []
2459- # for axis in zerosum_axes:
2460- # zerosum_axes_.append(dims.index(axis))
2461- # zerosum_axes = zerosum_axes_
2454+ support_shape = get_support_shape (
2455+ support_shape = support_shape ,
2456+ shape = None , # Shape will be checked in `cls.dist`
2457+ dims = dims ,
2458+ observed = kwargs .get ("observed" , None ),
2459+ ndim_supp = zerosum_axes ,
2460+ )
24622461
2463- # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
2462+ return super ().__new__ (
2463+ cls , * args , zerosum_axes = zerosum_axes , support_shape = support_shape , dims = dims , ** kwargs
2464+ )
24642465
24652466 @classmethod
24662467 def dist (cls , sigma = 1 , zerosum_axes = None , support_shape = None , ** kwargs ):
@@ -2480,10 +2481,13 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
24802481 shape = kwargs .get ("shape" ),
24812482 ndim_supp = zerosum_axes ,
24822483 )
2484+
2485+ # print(f"{support_shape.eval() = }")
2486+
24832487 if support_shape is None :
24842488 if zerosum_axes > 0 :
24852489 raise ValueError ("You must specify shape or support_shape parameter" )
2486- # edge case doesn't work for now, because at.stack in get_support_shape fails
2490+ # edge- case doesn't work for now, because at.stack in get_support_shape fails
24872491 # else:
24882492 # support_shape = () # because it's just a Normal in that case
24892493 support_shape = at .as_tensor_variable (intX (support_shape ))
@@ -2511,6 +2515,16 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
25112515 # check_dist_not_registered(dist)
25122516 # return super().dist([dist, lower, upper], **kwargs)
25132517
2518+ @classmethod
2519+ def check_zerosum_axes (cls , zerosum_axes : Optional [int ]) -> int :
2520+ if zerosum_axes is None :
2521+ zerosum_axes = 1
2522+ if not isinstance (zerosum_axes , int ):
2523+ raise TypeError ("zerosum_axes has to be an integer" )
2524+ if not zerosum_axes > 0 :
2525+ raise ValueError ("zerosum_axes has to be > 0" )
2526+ return zerosum_axes
2527+
25142528 @classmethod
25152529 def rv_op (cls , sigma , zerosum_axes , support_shape , size = None ):
25162530 # if size is None:
@@ -2553,11 +2567,14 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25532567
25542568@_change_dist_size .register (ZeroSumNormalRV )
25552569def change_zerosum_size (op , normal_dist , new_size , expand = False ):
2570+
25562571 normal_dist , sigma , support_shape = normal_dist .owner .inputs
2572+
25572573 if expand :
25582574 original_shape = tuple (normal_dist .shape )
25592575 old_size = original_shape [len (original_shape ) - op .ndim_supp :]
25602576 new_size = tuple (new_size ) + old_size
2577+
25612578 return ZeroSumNormal .rv_op (
25622579 sigma = sigma , zerosum_axes = op .ndim_supp , support_shape = support_shape , size = new_size
25632580 )
@@ -2570,26 +2587,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs):
25702587
25712588@_default_transform .register (ZeroSumNormalRV )
25722589def zerosum_default_transform (op , rv ):
2573- return ZeroSumTransform (op .zerosum_axes )
2590+ zerosum_axes = tuple (np .arange (- op .ndim_supp , 0 ))
2591+ return ZeroSumTransform (zerosum_axes )
25742592
25752593
25762594@_logprob .register (ZeroSumNormalRV )
25772595def zerosumnormal_logp (op , values , normal_dist , sigma , support_shape , ** kwargs ):
25782596 (value ,) = values
25792597 shape = value .shape
25802598 zerosum_axes = op .ndim_supp
2599+
25812600 _deg_free_support_shape = at .inc_subtensor (shape [- zerosum_axes :], - 1 )
25822601 _full_size = at .prod (shape )
25832602 _degrees_of_freedom = at .prod (_deg_free_support_shape )
2603+
25842604 zerosums = [
25852605 at .all (at .isclose (at .mean (value , axis = - axis - 1 ), 0 , atol = 1e-9 ))
25862606 for axis in range (zerosum_axes )
25872607 ]
2608+
25882609 out = at .sum (
25892610 pm .logp (normal_dist , value ) * _degrees_of_freedom / _full_size ,
25902611 axis = tuple (np .arange (- zerosum_axes , 0 )),
25912612 )
2592- # figure out how dimensionality should be handled for logp
2593- # for now, we assume ZSN is a scalar distribut, which is not correct
2594- # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
2613+
25952614 return check_parameters (out , * zerosums , msg = "at.mean(value, axis=zerosum_axes) == 0" )
0 commit comments