Skip to content

Commit 6dbf60c

Browse files
jnetzel1ricardoV94
andcommitted
Use automatic logprob for LogitNormal to support icdf
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent c1eda3a commit 6dbf60c

File tree

4 files changed

+26
-57
lines changed

4 files changed

+26
-57
lines changed

pymc/distributions/continuous.py

Lines changed: 19 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pytensor.tensor.random.utils import normalize_size_param
5454
from pytensor.tensor.variable import TensorConstant, TensorVariable
5555

56+
from pymc.distributions.custom import CustomDist
5657
from pymc.logprob.abstract import _logprob_helper
5758
from pymc.logprob.basic import TensorLike, icdf
5859
from pymc.pytensorf import normalize_rng_param
@@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs):
9293
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
9394
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
9495
from pymc.distributions.transforms import _default_transform
95-
from pymc.math import invlogit, logdiffexp, logit
96+
from pymc.math import invlogit, logdiffexp
9697

9798
__all__ = [
9899
"AsymmetricLaplace",
@@ -3603,28 +3604,7 @@ def icdf(value, mu, s):
36033604
)
36043605

36053606

3606-
class LogitNormalRV(SymbolicRandomVariable):
3607-
name = "logit_normal"
3608-
extended_signature = "[rng],[size],(),()->[rng],()"
3609-
_print_name = ("LogitNormal", "\\operatorname{LogitNormal}")
3610-
3611-
@classmethod
3612-
def rv_op(cls, mu, sigma, *, size=None, rng=None):
3613-
mu = pt.as_tensor(mu)
3614-
sigma = pt.as_tensor(sigma)
3615-
rng = normalize_rng_param(rng)
3616-
size = normalize_size_param(size)
3617-
3618-
next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
3619-
draws = pt.expit(normal_draws)
3620-
3621-
return cls(
3622-
inputs=[rng, size, mu, sigma],
3623-
outputs=[next_rng, draws],
3624-
)(rng, size, mu, sigma)
3625-
3626-
3627-
class LogitNormal(UnitContinuous):
3607+
class LogitNormal:
36283608
r"""
36293609
Logit-Normal distribution.
36303610
@@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous):
36723652
Defaults to 1.
36733653
"""
36743654

3675-
rv_type = LogitNormalRV
3676-
rv_op = LogitNormalRV.rv_op
3655+
@staticmethod
3656+
def logitnormal_dist(mu, sigma, size):
3657+
return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size))
36773658

3678-
@classmethod
3679-
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
3659+
def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs):
36803660
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3681-
return super().dist([mu, sigma], **kwargs)
3682-
3683-
def support_point(rv, size, mu, sigma):
3684-
median, _ = pt.broadcast_arrays(invlogit(mu), sigma)
3685-
if not rv_size_is_none(size):
3686-
median = pt.full(size, median)
3687-
return median
3688-
3689-
def logp(value, mu, sigma):
3690-
tau, _ = get_tau_sigma(sigma=sigma)
3691-
3692-
res = pt.switch(
3693-
pt.or_(pt.le(value, 0), pt.ge(value, 1)),
3694-
-np.inf,
3695-
(
3696-
-0.5 * tau * (logit(value) - mu) ** 2
3697-
+ 0.5 * pt.log(tau / (2.0 * np.pi))
3698-
- pt.log(value * (1 - value))
3699-
),
3661+
return CustomDist(
3662+
name,
3663+
mu,
3664+
sigma,
3665+
dist=cls.logitnormal_dist,
3666+
class_name="LogitNormal",
3667+
**kwargs,
37003668
)
37013669

3702-
return check_parameters(
3703-
res,
3704-
tau > 0,
3705-
msg="tau > 0",
3670+
@classmethod
3671+
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
3672+
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3673+
return CustomDist.dist(
3674+
mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs
37063675
)
37073676

37083677

pymc/distributions/moments/means.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
HalfFlatRV,
6060
HalfStudentTRV,
6161
KumaraswamyRV,
62-
LogitNormalRV,
6362
MoyalRV,
6463
PolyaGammaRV,
6564
RiceRV,
@@ -290,11 +289,6 @@ def logistic_mean(op, rv, rng, size, mu, s):
290289
return maybe_resize(pt.broadcast_arrays(mu, s)[0], size)
291290

292291

293-
@_mean.register(LogitNormalRV)
294-
def logitnormal_mean(op, rv, rng, size, mu, sigma):
295-
raise UndefinedMomentException("The mean of the LogitNormal distribution is undefined")
296-
297-
298292
@_mean.register(LogNormalRV)
299293
def lognormal_mean(op, rv, rng, size, mu, sigma):
300294
return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size)

tests/distributions/moments/test_means.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,5 @@ def test_mean_equal_expected(dist, dist_params, expected):
274274
],
275275
)
276276
def test_no_mean(dist, dist_params):
277-
with pytest.raises(UndefinedMomentException):
277+
with pytest.raises((UndefinedMomentException, NotImplementedError)):
278278
mean(dist.dist(**dist_params))

tests/distributions/test_continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,12 @@ def test_logitnormal(self):
872872
),
873873
decimal=select_by_precision(float64=6, float32=1),
874874
)
875+
check_icdf(
876+
pm.LogitNormal,
877+
{"mu": R, "sigma": Rplus},
878+
lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
879+
decimal=select_by_precision(float64=12, float32=5),
880+
)
875881

876882
@pytest.mark.skipif(
877883
condition=(pytensor.config.floatX == "float32"),

0 commit comments

Comments
 (0)