From c2fe0ae5bc13084425f459285d7226fb4ec72667 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 9 Nov 2025 22:28:14 +0100 Subject: [PATCH 1/4] Allow explicit rng for CustomDist that only require one --- pymc/distributions/custom.py | 12 +++++++- tests/distributions/test_custom.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py index 5bf3c5995d..7829aedb8d 100644 --- a/pymc/distributions/custom.py +++ b/pymc/distributions/custom.py @@ -259,6 +259,7 @@ def rv_op( size=None, signature: str, class_name: str, + rng=None, ): size = normalize_size_param(size) # If it's NoneConst, just use that as the dummy @@ -270,7 +271,8 @@ def rv_op( dummy_rv = dist(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param, *dummy_dist_params] # RNGs are not passed as explicit inputs (because we usually don't know how many are needed) - # We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op + # We retrieve them here. This will also raise if the user forgot to specify some update in an InnerGraphOp (e.g., Scan) + # If the user passed an explicit rng we will respect that later when we instantiate the final rv_op dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) rv_type = type( @@ -357,6 +359,14 @@ def change_custom_dist_size(op, rv, new_size, expand): outputs=outputs, extended_signature=extended_signature, ) + if rng is not None: + # User passed an RNG, use that if the graph only required one, raise otherwise + if len(rngs) != 1: + raise ValueError( + f"CustomDist received an explicit rng but it actually requires {len(rngs)} rngs." + " Please modify your dist function to only use one rng, or don't pass an explicitly rng." + ) + rngs = (rng,) return rv_op(size, *dist_params, *rngs) @staticmethod diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py index 7e05b2d02d..201594e037 100644 --- a/tests/distributions/test_custom.py +++ b/tests/distributions/test_custom.py @@ -708,3 +708,47 @@ def normal_shifted(mu, size): observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}), expected_logp, ) + + def test_explicit_rng(self): + def custom_dist(mu, size): + return Normal.dist(mu, size=size) + + x = CustomDist.dist(0, dist=custom_dist) + assert len(x.owner.op.rng_params(x.owner)) == 1 # Rng created by default + + explicit_rng = pt.random.type.random_generator_type("rng") + x_explicit = CustomDist.dist(0, dist=custom_dist, rng=explicit_rng) + [used_rng] = x_explicit.owner.op.rng_params(x_explicit.owner) + assert used_rng is explicit_rng + + # API for passing multiple explicit RNGs not supported + def custom_dist_multi_rng(mu, size): + return Normal.dist(mu, size=size) + Normal.dist(0, size=size) + + x = CustomDist.dist(0, dist=custom_dist_multi_rng) + assert len(x.owner.op.rng_params(x.owner)) == 2 + + with pytest.raises( + ValueError, + match="CustomDist received an explicit rng but it actually requires 2 rngs", + ): + CustomDist.dist( + 0, + dist=custom_dist_multi_rng, + rng=explicit_rng, + ) + + # But it can be done if the custom_dist uses only one RNG internally + def custom_dist_multi_rng_fixed(mu, size): + next_rng, x = Normal.dist(mu, size=size).owner.outputs + return x + Normal.dist(0, size=size, rng=next_rng) + + x = CustomDist.dist(0, dist=custom_dist_multi_rng_fixed) + assert len(x.owner.op.rng_params(x.owner)) == 1 + x_explicit = CustomDist.dist( + 0, + dist=custom_dist_multi_rng_fixed, + rng=explicit_rng, + ) + [used_rng] = x_explicit.owner.op.rng_params(x_explicit.owner) + assert used_rng is explicit_rng From 1f325bdc0ad35543b984c5ffa4fc8b7de66467a8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 8 Nov 2025 12:39:01 +0100 Subject: [PATCH 2/4] Allow derived icdf for negative non-even power transforms --- pymc/logprob/transforms.py | 16 ++++++++++++++-- tests/logprob/test_transforms.py | 14 +++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8b5eac7b16..8d2bbacd26 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise): Erf, Erfc, Erfcx, + Sigmoid, ) # Cannot use `transform` as name because it would clash with the property added by @@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian) -MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf) +MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid) MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx) @@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs) value = pt.switch(pt.lt(scale, 0), 1 - value, value) elif isinstance(op.scalar_op, Pow): if op.transform_elemwise.power < 0: - raise NotImplementedError + # Note: Negative even powers will be rejected below when inverting the transform + # For the remaining negative powers the function is decreasing with a jump around 0 + # We adjust the value with the mass below zero. + # For non-negative RVs with cdf(0)=0, it simplifies to 1 - value + cdf_zero = pt.exp(_logcdf_helper(measurable_input, 0)) + # Use nan to not mask invalid values accidentally + value = pt.switch((value >= 0) & (value <= 1), value, np.nan) + value = pt.switch( + (cdf_zero > 0) & (value < cdf_zero), + cdf_zero - value, + 1 + cdf_zero - value, + ) else: raise NotImplementedError diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 691c696e8f..c9aeaa8abf 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -379,9 +379,7 @@ def test_reciprocal_rv_transform(self, numerator): x_vv = x_rv.clone() x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv)) - - with pytest.raises(NotImplementedError): - icdf(x_rv, x_vv) + x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv)) x_test_val = np.r_[-0.5, 1.5] np.testing.assert_allclose( @@ -392,6 +390,10 @@ def test_reciprocal_rv_transform(self, numerator): x_logcdf_fn(x_test_val), sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val), ) + np.testing.assert_allclose( + x_icdf_fn(x_test_val), + sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_val), + ) def test_reciprocal_real_rv_transform(self): # 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma ^2), sigma / (mu ^ 2, sigma ^ 2)) @@ -406,8 +408,10 @@ def test_reciprocal_real_rv_transform(self): logcdf(test_rv, test_value).eval(), sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value), ) - with pytest.raises(NotImplementedError): - icdf(test_rv, test_value) + np.testing.assert_allclose( + icdf(test_rv, test_value).eval(), + sp.stats.cauchy(1 / 5, 2 / 5).ppf(test_value), + ) def test_sqr_transform(self): # The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2 From f1b3d3cdcb6b10648a33fa7227d91ac633bd4f18 Mon Sep 17 00:00:00 2001 From: jnetzel1 Date: Fri, 26 Sep 2025 15:16:18 +0200 Subject: [PATCH 3/4] Use automatic logprob for LogitNormal to support icdf Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com> --- pymc/distributions/continuous.py | 69 +++++++---------------- pymc/distributions/moments/means.py | 6 -- tests/distributions/moments/test_means.py | 2 +- tests/distributions/test_continuous.py | 6 ++ 4 files changed, 26 insertions(+), 57 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index fd6e414605..886364ee55 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -53,6 +53,7 @@ from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.variable import TensorConstant, TensorVariable +from pymc.distributions.custom import CustomDist from pymc.logprob.abstract import _logprob_helper from pymc.logprob.basic import TensorLike, icdf from pymc.pytensorf import normalize_rng_param @@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs): from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none from pymc.distributions.transforms import _default_transform -from pymc.math import invlogit, logdiffexp, logit +from pymc.math import invlogit, logdiffexp __all__ = [ "AsymmetricLaplace", @@ -3603,28 +3604,7 @@ def icdf(value, mu, s): ) -class LogitNormalRV(SymbolicRandomVariable): - name = "logit_normal" - extended_signature = "[rng],[size],(),()->[rng],()" - _print_name = ("LogitNormal", "\\operatorname{LogitNormal}") - - @classmethod - def rv_op(cls, mu, sigma, *, size=None, rng=None): - mu = pt.as_tensor(mu) - sigma = pt.as_tensor(sigma) - rng = normalize_rng_param(rng) - size = normalize_size_param(size) - - next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs - draws = pt.expit(normal_draws) - - return cls( - inputs=[rng, size, mu, sigma], - outputs=[next_rng, draws], - )(rng, size, mu, sigma) - - -class LogitNormal(UnitContinuous): +class LogitNormal: r""" Logit-Normal distribution. @@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous): Defaults to 1. """ - rv_type = LogitNormalRV - rv_op = LogitNormalRV.rv_op + @staticmethod + def logitnormal_dist(mu, sigma, size): + return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size)) - @classmethod - def dist(cls, mu=0, sigma=None, tau=None, **kwargs): + def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs): _, sigma = get_tau_sigma(tau=tau, sigma=sigma) - return super().dist([mu, sigma], **kwargs) - - def support_point(rv, size, mu, sigma): - median, _ = pt.broadcast_arrays(invlogit(mu), sigma) - if not rv_size_is_none(size): - median = pt.full(size, median) - return median - - def logp(value, mu, sigma): - tau, _ = get_tau_sigma(sigma=sigma) - - res = pt.switch( - pt.or_(pt.le(value, 0), pt.ge(value, 1)), - -np.inf, - ( - -0.5 * tau * (logit(value) - mu) ** 2 - + 0.5 * pt.log(tau / (2.0 * np.pi)) - - pt.log(value * (1 - value)) - ), + return CustomDist( + name, + mu, + sigma, + dist=cls.logitnormal_dist, + class_name="LogitNormal", + **kwargs, ) - return check_parameters( - res, - tau > 0, - msg="tau > 0", + @classmethod + def dist(cls, mu=0, sigma=None, tau=None, **kwargs): + _, sigma = get_tau_sigma(tau=tau, sigma=sigma) + return CustomDist.dist( + mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs ) diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py index 0e3129935e..34687a7ba2 100644 --- a/pymc/distributions/moments/means.py +++ b/pymc/distributions/moments/means.py @@ -59,7 +59,6 @@ HalfFlatRV, HalfStudentTRV, KumaraswamyRV, - LogitNormalRV, MoyalRV, PolyaGammaRV, RiceRV, @@ -290,11 +289,6 @@ def logistic_mean(op, rv, rng, size, mu, s): return maybe_resize(pt.broadcast_arrays(mu, s)[0], size) -@_mean.register(LogitNormalRV) -def logitnormal_mean(op, rv, rng, size, mu, sigma): - raise UndefinedMomentException("The mean of the LogitNormal distribution is undefined") - - @_mean.register(LogNormalRV) def lognormal_mean(op, rv, rng, size, mu, sigma): return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size) diff --git a/tests/distributions/moments/test_means.py b/tests/distributions/moments/test_means.py index abfa9ee376..90d6766329 100644 --- a/tests/distributions/moments/test_means.py +++ b/tests/distributions/moments/test_means.py @@ -274,5 +274,5 @@ def test_mean_equal_expected(dist, dist_params, expected): ], ) def test_no_mean(dist, dist_params): - with pytest.raises(UndefinedMomentException): + with pytest.raises((UndefinedMomentException, NotImplementedError)): mean(dist.dist(**dist_params)) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 7209382666..e1e9b467d5 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -872,6 +872,12 @@ def test_logitnormal(self): ), decimal=select_by_precision(float64=6, float32=1), ) + check_icdf( + pm.LogitNormal, + {"mu": R, "sigma": Rplus}, + lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)), + decimal=select_by_precision(float64=12, float32=5), + ) @pytest.mark.skipif( condition=(pytensor.config.floatX == "float32"), From eec79eeb2c149b0b4973fdd820a69e14b81b1d53 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 8 Nov 2025 12:43:20 +0100 Subject: [PATCH 4/4] Support InverseGamma icdf Co-authored-by: lucaseckes --- pymc/distributions/continuous.py | 3 +++ tests/distributions/test_continuous.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 886364ee55..204eef7c11 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2532,6 +2532,9 @@ def logcdf(value, alpha, beta): msg="alpha > 0, beta > 0", ) + def icdf(value, alpha, beta): + return icdf(1 / Gamma.dist(alpha, beta), value) + class ChiSquared: r""" diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index e1e9b467d5..a2d949f3d0 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -687,6 +687,13 @@ def test_inverse_gamma_logcdf(self): lambda value, alpha, beta: st.invgamma.logcdf(value, alpha, scale=beta), ) + def test_inverse_gamma_icdf(self): + check_icdf( + pm.InverseGamma, + {"alpha": Rplusbig, "beta": Rplusbig}, + lambda q, alpha, beta: st.invgamma.ppf(q, alpha, scale=beta), + ) + @pytest.mark.skipif( condition=(pytensor.config.floatX == "float32"), reason="Fails on float32 due to scaling issues",