@@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
178178 Erf ,
179179 Erfc ,
180180 Erfcx ,
181+ Sigmoid ,
181182 )
182183
183184 # Cannot use `transform` as name because it would clash with the property added by
@@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
227228 return pt .switch (pt .isnan (jacobian ), - np .inf , input_logprob + jacobian )
228229
229230
230- MONOTONICALLY_INCREASING_OPS = (Exp , Log , Add , Sinh , Tanh , ArcSinh , ArcCosh , ArcTanh , Erf )
231+ MONOTONICALLY_INCREASING_OPS = (Exp , Log , Add , Sinh , Tanh , ArcSinh , ArcCosh , ArcTanh , Erf , Sigmoid )
231232MONOTONICALLY_DECREASING_OPS = (Erfc , Erfcx )
232233
233234
@@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
300301 value = pt .switch (pt .lt (scale , 0 ), 1 - value , value )
301302 elif isinstance (op .scalar_op , Pow ):
302303 if op .transform_elemwise .power < 0 :
303- raise NotImplementedError
304+ # Note: Negative even powers will be rejected below when inverting the transform
305+ # For the remaining negative powers the function is decreasing with a jump around 0
306+ # We adjust the value with the mass below zero.
307+ # For non-negative RVs with cdf(0)=0, it simplifies to 1 - value
308+ cdf_zero = pt .exp (_logcdf_helper (measurable_input , 0 ))
309+ # Use nan to not mask invalid values accidentally
310+ value = pt .switch ((value >= 0 ) & (value <= 1 ), value , np .nan )
311+ value = pt .switch (
312+ (cdf_zero > 0 ) & (value < cdf_zero ),
313+ cdf_zero - value ,
314+ 1 + cdf_zero - value ,
315+ )
304316 else :
305317 raise NotImplementedError
306318
0 commit comments