@@ -69,13 +69,10 @@ def polyagamma_cdf(*args, **kwargs):
6969 raise RuntimeError ("polyagamma package is not installed!" )
7070
7171
72- from numpy .core .numeric import normalize_axis_tuple
7372from scipy import stats
7473from scipy .interpolate import InterpolatedUnivariateSpline
7574from scipy .special import expit
7675
77- import pymc as pm
78-
7976from pymc .aesaraf import floatX
8077from pymc .distributions import transforms
8178from pymc .distributions .dist_math import (
@@ -89,28 +86,16 @@ def polyagamma_cdf(*args, **kwargs):
8986 normal_lcdf ,
9087 zvalue ,
9188)
92- from pymc .distributions .distribution import (
93- DIST_PARAMETER_TYPES ,
94- Continuous ,
95- Distribution ,
96- SymbolicRandomVariable ,
97- _moment ,
98- )
99- from pymc .distributions .logprob import ignore_logprob
100- from pymc .distributions .shape_utils import (
101- _change_dist_size ,
102- convert_dims ,
103- rv_size_is_none ,
104- )
105- from pymc .distributions .transforms import ZeroSumTransform , _default_transform
89+ from pymc .distributions .distribution import DIST_PARAMETER_TYPES , Continuous
90+ from pymc .distributions .shape_utils import rv_size_is_none
91+ from pymc .distributions .transforms import _default_transform
10692from pymc .math import invlogit , logdiffexp , logit
10793
10894__all__ = [
10995 "Uniform" ,
11096 "Flat" ,
11197 "HalfFlat" ,
11298 "Normal" ,
113- "ZeroSumNormal" ,
11499 "TruncatedNormal" ,
115100 "Beta" ,
116101 "Kumaraswamy" ,
@@ -600,172 +585,6 @@ def logcdf(value, mu, sigma):
600585 )
601586
602587
603- class ZeroSumNormalRV (SymbolicRandomVariable ):
604- """ZeroSumNormal random variable"""
605-
606- _print_name = ("ZeroSumNormal" , "\\ operatorname{ZeroSumNormal}" )
607- zerosum_axes = None
608-
609- def __init__ (self , * args , zerosum_axes , ** kwargs ):
610- self .zerosum_axes = zerosum_axes
611- super ().__init__ (* args , ** kwargs )
612-
613-
614- class ZeroSumNormal (Distribution ):
615- r"""
616- ZeroSumNormal distribution, i.e Normal distribution where one or
617- several axes are constrained to sum to zero.
618- By default, the last axis is constrained to sum to zero.
619- See `zerosum_axes` kwarg for more details.
620-
621- Parameters
622- ----------
623- sigma : tensor_like of float
624- Standard deviation (sigma > 0).
625- Defaults to 1 if not specified.
626- For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
627- zerosum_axes: list or tuple of strings or integers
628- Axis (or axes) along which the zero-sum constraint is enforced.
629- Defaults to [-1], i.e the last axis.
630- If strings are passed, then ``dims`` is needed.
631- Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
632- dims: list or tuple of strings, optional
633- The dimension names of the axes.
634- Necessary when ``zerosum_axes`` is specified with strings.
635-
636- Warnings
637- --------
638- ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
639- The ability to specifiy a vector of ``sigma`` may be added in future versions.
640-
641- Examples
642- --------
643- .. code-block:: python
644- COORDS = {
645- "regions": ["a", "b", "c"],
646- "answers": ["yes", "no", "whatever", "don't understand question"],
647- }
648- with pm.Model(coords=COORDS) as m:
649- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
650-
651- with pm.Model(coords=COORDS) as m:
652- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
653-
654- with pm.Model(coords=COORDS) as m:
655- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
656- """
657- rv_type = ZeroSumNormalRV
658-
659- def __new__ (cls , * args , zerosum_axes = None , dims = None , ** kwargs ):
660- dims = convert_dims (dims )
661- if zerosum_axes is None :
662- zerosum_axes = [- 1 ]
663- if not isinstance (zerosum_axes , (list , tuple )):
664- zerosum_axes = [zerosum_axes ]
665-
666- if isinstance (zerosum_axes [0 ], str ):
667- if not dims :
668- raise ValueError ("You need to specify dims if zerosum_axes are strings." )
669- else :
670- zerosum_axes_ = []
671- for axis in zerosum_axes :
672- zerosum_axes_ .append (dims .index (axis ))
673- zerosum_axes = zerosum_axes_
674-
675- return super ().__new__ (cls , * args , zerosum_axes = zerosum_axes , dims = dims , ** kwargs )
676-
677- @classmethod
678- def dist (cls , sigma = 1 , zerosum_axes = None , ** kwargs ):
679- if zerosum_axes is None :
680- zerosum_axes = [- 1 ]
681-
682- sigma = at .as_tensor_variable (floatX (sigma ))
683- if sigma .ndim > 0 :
684- raise ValueError ("sigma has to be a scalar" )
685-
686- return super ().dist ([sigma ], zerosum_axes = zerosum_axes , ** kwargs )
687-
688- # TODO: This is if we want ZeroSum constraint on other dists than Normal
689- # def dist(cls, dist, lower, upper, **kwargs):
690- # if not isinstance(dist, TensorVariable) or not isinstance(
691- # dist.owner.op, (RandomVariable, SymbolicRandomVariable)
692- # ):
693- # raise ValueError(
694- # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
695- # )
696- # if dist.owner.op.ndim_supp > 0:
697- # raise NotImplementedError(
698- # "Censoring of multivariate distributions has not been implemented yet"
699- # )
700- # check_dist_not_registered(dist)
701- # return super().dist([dist, lower, upper], **kwargs)
702-
703- @classmethod
704- def rv_op (cls , sigma , zerosum_axes , size = None ):
705- if size is None :
706- zerosum_axes_ = np .asarray (zerosum_axes )
707- # just a placeholder size to infer minimum shape
708- size = np .ones (
709- max ((max (np .abs (zerosum_axes_ ) - 1 ), max (zerosum_axes_ ))) + 1 , dtype = int
710- ).tolist ()
711-
712- # check if zerosum_axes is valid
713- normalize_axis_tuple (zerosum_axes , len (size ))
714-
715- normal_dist = ignore_logprob (pm .Normal .dist (sigma = sigma , size = size ))
716- normal_dist_ , sigma_ = normal_dist .type (), sigma .type ()
717-
718- # Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes
719- zerosum_rv_ = normal_dist_
720- for axis in zerosum_axes :
721- zerosum_rv_ -= zerosum_rv_ .mean (axis = axis , keepdims = True )
722-
723- return ZeroSumNormalRV (
724- inputs = [normal_dist_ , sigma_ ],
725- outputs = [zerosum_rv_ ],
726- zerosum_axes = zerosum_axes ,
727- ndim_supp = 0 ,
728- )(normal_dist , sigma )
729-
730-
731- @_change_dist_size .register (ZeroSumNormalRV )
732- def change_zerosum_size (op , normal_dist , new_size , expand = False ):
733- normal_dist , sigma = normal_dist .owner .inputs
734- if expand :
735- new_size = tuple (new_size ) + tuple (normal_dist .shape )
736- return ZeroSumNormal .rv_op (sigma = sigma , zerosum_axes = op .zerosum_axes , size = new_size )
737-
738-
739- @_moment .register (ZeroSumNormalRV )
740- def zerosumnormal_moment (op , rv , * rv_inputs ):
741- return at .zeros_like (rv )
742-
743-
744- @_default_transform .register (ZeroSumNormalRV )
745- def zerosum_default_transform (op , rv ):
746- return ZeroSumTransform (op .zerosum_axes )
747-
748-
749- @_logprob .register (ZeroSumNormalRV )
750- def zerosumnormal_logp (op , values , normal_dist , sigma , ** kwargs ):
751- (value ,) = values
752- shape = value .shape
753- _deg_free_shape = at .inc_subtensor (shape [at .as_tensor_variable (op .zerosum_axes )], - 1 )
754- _full_size = at .prod (shape )
755- _degrees_of_freedom = at .prod (_deg_free_shape )
756- zerosums = [
757- at .all (at .isclose (at .mean (value , axis = axis ), 0 , atol = 1e-9 )) for axis in op .zerosum_axes
758- ]
759- # out = at.sum(
760- # pm.logp(dist, value) * _degrees_of_freedom / _full_size,
761- # axis=op.zerosum_axes,
762- # )
763- # figure out how dimensionality should be handled for logp
764- # for now, we assume ZSN is a scalar distribut, which is not correct
765- out = pm .logp (normal_dist , value ) * _degrees_of_freedom / _full_size
766- return check_parameters (out , * zerosums , msg = "at.mean(value, axis=zerosum_axes) == 0" )
767-
768-
769588class TruncatedNormalRV (RandomVariable ):
770589 name = "truncated_normal"
771590 ndim_supp = 0
0 commit comments