Skip to content

Commit 854ef4c

Browse files
committed
Move ZSN to multivariate.py
1 parent 0bdcdd7 commit 854ef4c

File tree

3 files changed

+175
-187
lines changed

3 files changed

+175
-187
lines changed

pymc/distributions/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
VonMises,
5757
Wald,
5858
Weibull,
59-
ZeroSumNormal,
6059
)
6160
from pymc.distributions.discrete import (
6261
Bernoulli,
@@ -100,6 +99,7 @@
10099
StickBreakingWeights,
101100
Wishart,
102101
WishartBartlett,
102+
ZeroSumNormal,
103103
)
104104
from pymc.distributions.simulator import Simulator
105105
from pymc.distributions.timeseries import (
@@ -118,7 +118,6 @@
118118
"HalfFlat",
119119
"Normal",
120120
"TruncatedNormal",
121-
"ZeroSumNormal",
122121
"Beta",
123122
"Kumaraswamy",
124123
"Exponential",
@@ -161,6 +160,7 @@
161160
"Continuous",
162161
"Discrete",
163162
"MvNormal",
163+
"ZeroSumNormal",
164164
"MatrixNormal",
165165
"KroneckerNormal",
166166
"MvStudentT",

pymc/distributions/continuous.py

Lines changed: 3 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,10 @@ def polyagamma_cdf(*args, **kwargs):
6969
raise RuntimeError("polyagamma package is not installed!")
7070

7171

72-
from numpy.core.numeric import normalize_axis_tuple
7372
from scipy import stats
7473
from scipy.interpolate import InterpolatedUnivariateSpline
7574
from scipy.special import expit
7675

77-
import pymc as pm
78-
7976
from pymc.aesaraf import floatX
8077
from pymc.distributions import transforms
8178
from pymc.distributions.dist_math import (
@@ -89,28 +86,16 @@ def polyagamma_cdf(*args, **kwargs):
8986
normal_lcdf,
9087
zvalue,
9188
)
92-
from pymc.distributions.distribution import (
93-
DIST_PARAMETER_TYPES,
94-
Continuous,
95-
Distribution,
96-
SymbolicRandomVariable,
97-
_moment,
98-
)
99-
from pymc.distributions.logprob import ignore_logprob
100-
from pymc.distributions.shape_utils import (
101-
_change_dist_size,
102-
convert_dims,
103-
rv_size_is_none,
104-
)
105-
from pymc.distributions.transforms import ZeroSumTransform, _default_transform
89+
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
90+
from pymc.distributions.shape_utils import rv_size_is_none
91+
from pymc.distributions.transforms import _default_transform
10692
from pymc.math import invlogit, logdiffexp, logit
10793

10894
__all__ = [
10995
"Uniform",
11096
"Flat",
11197
"HalfFlat",
11298
"Normal",
113-
"ZeroSumNormal",
11499
"TruncatedNormal",
115100
"Beta",
116101
"Kumaraswamy",
@@ -600,172 +585,6 @@ def logcdf(value, mu, sigma):
600585
)
601586

602587

603-
class ZeroSumNormalRV(SymbolicRandomVariable):
604-
"""ZeroSumNormal random variable"""
605-
606-
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
607-
zerosum_axes = None
608-
609-
def __init__(self, *args, zerosum_axes, **kwargs):
610-
self.zerosum_axes = zerosum_axes
611-
super().__init__(*args, **kwargs)
612-
613-
614-
class ZeroSumNormal(Distribution):
615-
r"""
616-
ZeroSumNormal distribution, i.e. a Normal distribution where one or
617-
several axes are constrained to sum to zero.
618-
By default, the last axis is constrained to sum to zero.
619-
See `zerosum_axes` kwarg for more details.
620-
621-
Parameters
622-
----------
623-
sigma : tensor_like of float
624-
Standard deviation (sigma > 0).
625-
Defaults to 1 if not specified.
626-
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
627-
zerosum_axes: list or tuple of strings or integers
628-
Axis (or axes) along which the zero-sum constraint is enforced.
629-
Defaults to [-1], i.e. the last axis.
630-
If strings are passed, then ``dims`` is needed.
631-
Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
632-
dims: list or tuple of strings, optional
633-
The dimension names of the axes.
634-
Necessary when ``zerosum_axes`` is specified with strings.
635-
636-
Warnings
637-
--------
638-
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
639-
The ability to specify a vector of ``sigma`` may be added in future versions.
640-
641-
Examples
642-
--------
643-
.. code-block:: python
644-
COORDS = {
645-
"regions": ["a", "b", "c"],
646-
"answers": ["yes", "no", "whatever", "don't understand question"],
647-
}
648-
with pm.Model(coords=COORDS) as m:
649-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
650-
651-
with pm.Model(coords=COORDS) as m:
652-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
653-
654-
with pm.Model(coords=COORDS) as m:
655-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
656-
"""
657-
rv_type = ZeroSumNormalRV
658-
659-
def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
660-
dims = convert_dims(dims)
661-
if zerosum_axes is None:
662-
zerosum_axes = [-1]
663-
if not isinstance(zerosum_axes, (list, tuple)):
664-
zerosum_axes = [zerosum_axes]
665-
666-
if isinstance(zerosum_axes[0], str):
667-
if not dims:
668-
raise ValueError("You need to specify dims if zerosum_axes are strings.")
669-
else:
670-
zerosum_axes_ = []
671-
for axis in zerosum_axes:
672-
zerosum_axes_.append(dims.index(axis))
673-
zerosum_axes = zerosum_axes_
674-
675-
return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
676-
677-
@classmethod
678-
def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
679-
if zerosum_axes is None:
680-
zerosum_axes = [-1]
681-
682-
sigma = at.as_tensor_variable(floatX(sigma))
683-
if sigma.ndim > 0:
684-
raise ValueError("sigma has to be a scalar")
685-
686-
return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs)
687-
688-
# TODO: This is if we want ZeroSum constraint on other dists than Normal
689-
# def dist(cls, dist, lower, upper, **kwargs):
690-
# if not isinstance(dist, TensorVariable) or not isinstance(
691-
# dist.owner.op, (RandomVariable, SymbolicRandomVariable)
692-
# ):
693-
# raise ValueError(
694-
# f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
695-
# )
696-
# if dist.owner.op.ndim_supp > 0:
697-
# raise NotImplementedError(
698-
# "Censoring of multivariate distributions has not been implemented yet"
699-
# )
700-
# check_dist_not_registered(dist)
701-
# return super().dist([dist, lower, upper], **kwargs)
702-
703-
@classmethod
704-
def rv_op(cls, sigma, zerosum_axes, size=None):
705-
if size is None:
706-
zerosum_axes_ = np.asarray(zerosum_axes)
707-
# just a placeholder size to infer minimum shape
708-
size = np.ones(
709-
max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
710-
).tolist()
711-
712-
# check if zerosum_axes is valid
713-
normalize_axis_tuple(zerosum_axes, len(size))
714-
715-
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size))
716-
normal_dist_, sigma_ = normal_dist.type(), sigma.type()
717-
718-
# Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
719-
zerosum_rv_ = normal_dist_
720-
for axis in zerosum_axes:
721-
zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True)
722-
723-
return ZeroSumNormalRV(
724-
inputs=[normal_dist_, sigma_],
725-
outputs=[zerosum_rv_],
726-
zerosum_axes=zerosum_axes,
727-
ndim_supp=0,
728-
)(normal_dist, sigma)
729-
730-
731-
@_change_dist_size.register(ZeroSumNormalRV)
732-
def change_zerosum_size(op, normal_dist, new_size, expand=False):
733-
normal_dist, sigma = normal_dist.owner.inputs
734-
if expand:
735-
new_size = tuple(new_size) + tuple(normal_dist.shape)
736-
return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size)
737-
738-
739-
@_moment.register(ZeroSumNormalRV)
740-
def zerosumnormal_moment(op, rv, *rv_inputs):
741-
return at.zeros_like(rv)
742-
743-
744-
@_default_transform.register(ZeroSumNormalRV)
745-
def zerosum_default_transform(op, rv):
746-
return ZeroSumTransform(op.zerosum_axes)
747-
748-
749-
@_logprob.register(ZeroSumNormalRV)
750-
def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs):
751-
(value,) = values
752-
shape = value.shape
753-
_deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1)
754-
_full_size = at.prod(shape)
755-
_degrees_of_freedom = at.prod(_deg_free_shape)
756-
zerosums = [
757-
at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes
758-
]
759-
# out = at.sum(
760-
# pm.logp(dist, value) * _degrees_of_freedom / _full_size,
761-
# axis=op.zerosum_axes,
762-
# )
763-
# figure out how dimensionality should be handled for logp
764-
# for now, we assume ZSN is a scalar distribution, which is not correct
765-
out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
766-
return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
767-
768-
769588
class TruncatedNormalRV(RandomVariable):
770589
name = "truncated_normal"
771590
ndim_supp = 0

0 commit comments

Comments
 (0)