Skip to content

Commit c1eda3a

Browse files
committed
Allow derived icdf for negative non-even power transforms
1 parent bfba9c3 commit c1eda3a

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

pymc/logprob/transforms.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
178178
Erf,
179179
Erfc,
180180
Erfcx,
181+
Sigmoid,
181182
)
182183

183184
# Cannot use `transform` as name because it would clash with the property added by
@@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
227228
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
228229

229230

230-
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
231+
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
231232
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)
232233

233234

@@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
300301
value = pt.switch(pt.lt(scale, 0), 1 - value, value)
301302
elif isinstance(op.scalar_op, Pow):
302303
if op.transform_elemwise.power < 0:
303-
raise NotImplementedError
304+
# Note: Negative even powers will be rejected below when inverting the transform
305+
# For the remaining negative powers the function is decreasing with a jump around 0
306+
# We adjust the value with the mass below zero.
307+
# For non-negative RVs with cdf(0)=0, it simplifies to 1 - value
308+
cdf_zero = pt.exp(_logcdf_helper(measurable_input, 0))
309+
# Use nan to not mask invalid values accidentally
310+
value = pt.switch((value >= 0) & (value <= 1), value, np.nan)
311+
value = pt.switch(
312+
(cdf_zero > 0) & (value < cdf_zero),
313+
cdf_zero - value,
314+
1 + cdf_zero - value,
315+
)
304316
else:
305317
raise NotImplementedError
306318

tests/logprob/test_transforms.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,7 @@ def test_reciprocal_rv_transform(self, numerator):
379379
x_vv = x_rv.clone()
380380
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
381381
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
382-
383-
with pytest.raises(NotImplementedError):
384-
icdf(x_rv, x_vv)
382+
x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))
385383

386384
x_test_val = np.r_[-0.5, 1.5]
387385
np.testing.assert_allclose(
@@ -392,6 +390,10 @@ def test_reciprocal_rv_transform(self, numerator):
392390
x_logcdf_fn(x_test_val),
393391
sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
394392
)
393+
np.testing.assert_allclose(
394+
x_icdf_fn(x_test_val),
395+
sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_val),
396+
)
395397

396398
def test_reciprocal_real_rv_transform(self):
397399
# 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma ^2), sigma / (mu ^ 2, sigma ^ 2))
@@ -406,8 +408,10 @@ def test_reciprocal_real_rv_transform(self):
406408
logcdf(test_rv, test_value).eval(),
407409
sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
408410
)
409-
with pytest.raises(NotImplementedError):
410-
icdf(test_rv, test_value)
411+
np.testing.assert_allclose(
412+
icdf(test_rv, test_value).eval(),
413+
sp.stats.cauchy(1 / 5, 2 / 5).ppf(test_value),
414+
)
411415

412416
def test_sqr_transform(self):
413417
# The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2

0 commit comments

Comments
 (0)