From 81984edde565c2c75916f1291383271e371a5793 Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Sat, 4 Oct 2025 22:02:21 +0200 Subject: [PATCH 1/6] Implement chainrules for beta_inc and beta_inc_inv and associated tests --- ext/SpecialFunctionsChainRulesCoreExt.jl | 431 +++++++++++++++++++++++ test/chainrules.jl | 127 ++++++- 2 files changed, 557 insertions(+), 1 deletion(-) diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl index 169ee237..0ce0c2ec 100644 --- a/ext/SpecialFunctionsChainRulesCoreExt.jl +++ b/ext/SpecialFunctionsChainRulesCoreExt.jl @@ -300,4 +300,435 @@ function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number) return Ω, besselyx_pullback end + +# Note on incomplete beta derivatives implementation +# -------------------------------------------------- +# The rules for the regularized incomplete beta I_x(a,b) and its inverse are +# implemented using a direct translation of the original S-PLUS/MATLAB code by +# Boik & Robinson-Cox. See: +# Boik, R. J., & Robinson-Cox, J. F. (1998). +# Derivatives of the incomplete beta function with respect to its parameters. +# Computational Statistics & Data Analysis, 27(1), 85–106. +# The coefficient recurrences and derivative accumulation are ported verbatim +# (scalar form) from inc.beta.deriv.S/inbeder.m. + +@inline function _derconf_coeffs(n::Int, p::T, q::T, w::T) where {T<:AbstractFloat} + F = w * q / p + if n == 1 + t1 = 1 - inv(p + 1) + t2 = 1 - inv(q) + t3 = 1 - 2 / (p + 2) + t4 = 1 - 2 / q + an1 = t1 * t2 * F + an2 = -an1 / (p + 1) + an4 = t1 * F / q + bn1 = 1 - t3 * t4 * F + bn2 = t3 * t4 * F / (p + 2) + bn4 = -t3 * F / q + return an1, an2, an4, bn1, bn2, bn4 + end + t2 = F^2 + t3 = 2n - 2 + t5 = p * q + t7 = inv(t3 * q + t5) + t8 = t2 * t7 + t9 = n^2 + t10 = t9^2 + t11 = t2 * t10 + t12 = 4n - 2 + t13 = q^2 + t14 = t12 * t13 + t15 = p * t13 + t17 = inv(t14 + 2t15) + t19 = t9 * n + t20 = t19 * t2 + t22 = inv(p + (2n - 1)) + t23 = t20 * t22 + t24 = 2n - 1 + t27 = inv(t24 * q + t5) + t28 = t20 * t27 + t30 = t10 * n * t2 + t32 = n * t2 + t33 = 2n - 3 + t36 = inv(t33 * t13 + t15) + t37 = t32 * t36 + t38 = t9 * t2 + t39 = inv(t13) + t41 = t32 * t39 + t43 = (-8 + 4n) * n + t47 = inv(4 + t43 + (4n - 4 + p) * p) + t49 = t38 * t17 + t50 = t38 * t47 + t51 = t20 * t47 + t52 = inv(q) + t54 = t2 * t47 + t55 = t32 * t47 + t57 = inv(2p + (4n - 6)) + t59 = 4t8 - 3t11 * t17 - 4t23 - t28 - 4t30 * t27 + 9t37 - t38 * t39 + t41 + 4t11 * t47 - t49 + 24t50 - 16t51 - t2 * t52 + 4t54 - 16t55 - 53t38 * t57 + t62 = inv(p + (2n - 2)) + t63 = t32 * t62 + t65 = inv(2p + (4n - 2)) + t69 = t2 * inv(p + (2n - 3)) + t70 = t69 * t19 + t73 = inv(t3 * t13 + t15) + t74 = t11 * t73 + t76 = t10 * t9 * t2 + t79 = inv(t24 * t13 + t15) + t81 = t2 * t62 + t82 = 4 + t43 + t84 = 4n - 4 + t89 = inv(t82 * t13 + (t84 * t13 + t15) * p) + t91 = t20 * t36 + t92 = t11 * t27 + t96 = t20 * t89 + t97 = t20 * t7 + t98 = t12 * q + t100 = inv(t98 + 2t5) + t102 = 51t32 * t57 - 24t63 + 5t38 * t65 + 12t70 + 40t74 + 2t76 * t79 + 8t81 + 4t76 * t89 + 52t91 + 6t92 - 2t69 * t10 - 8t20 * t62 + 2t11 * t22 - 16t96 - 64t97 + t32 * t100 + t104 = t38 * t62 + t105 = t30 * t36 + t107 = 4n - 6 + t108 = t107 * q + t110 = inv(t108 + 2t5) + t113 = t38 * t73 + t116 = inv(t33 * q + t5) + t117 = t11 * t116 + t118 = t20 * t116 + t119 = t30 * t79 + t120 = t32 * t73 + t122 = t20 * t73 + t123 = t20 * t79 + t126 = 24t104 + 14t105 + t32 * t52 + 87t32 * t110 - 9t69 - 12t30 * t73 + 24t113 - 26t117 + 65t118 - 2t119 - 4t120 + 4t30 * t116 - 48t122 + 2t123 - 2t76 * t36 - 3t38 * t100 + t132 = inv(t82 * q + (t84 * q + t5) * p) + t133 = t20 * t132 + t135 = t38 * t89 + t136 = t11 * t89 + t137 = t30 * t89 + t138 = t11 * t132 + t139 = t107 * t13 + t141 = inv(t139 + 2t15) + t142 = t38 * t141 + t143 = t32 * t132 + t144 = t32 * t7 + t145 = t38 * t7 + t149 = t38 * t132 + t151 = t2 * t116 + t152 = -48t133 - 8t30 * t132 + 4t135 + 24t136 - 16t137 + 32t138 - 69t142 - 8t143 - 32t144 + 72t145 - t32 * t65 + 20t11 * t7 - 77t11 * t141 + 32t149 - 155t38 * t110 - 9t151 + an1 = t59 + t102 + t126 + t152 + # an2 (∂/∂p) + t155 = (4n - 4) * n + t156 = 1 + t155 + t161 = inv(t156 * t13 + (t14 + t15) * p) + t162 = t30 * t161 + t163 = -8 + 8n + t164 = t163 * n + t165 = 2 + t164 + t167 = -4 + 8n + t172 = inv(t165 * t13 + (t167 * t13 + 2t15) * p) + t175 = (-24 + 8n) * n + t179 = inv(18 + t175 + (-12 + 8n + 2p) * p) + t181 = t20 * t161 + t182 = t38 * t22 + t184 = (24 + t175) * n + t186 = (-24 + 12n) * n + t192 = inv(-8 + t184 + (12 + t186 + (-6 + 6n + p) * p) * p) + t198 = inv(t156 * q + (t98 + t5) * p) + t199 = t11 * t198 + t200 = t20 * t192 + t201 = -4t8 + 2t162 + 3t11 * t172 - 51t32 * t179 + 2t23 + 4t28 - 2t181 - 3t182 - 8t11 * t192 - 6t199 + 32t200 - 6t37 + t207 = inv(t165 * q + (t167 * q + 2t5) * p) + t210 = (-12 + 4n) * n + t211 = 9 + t210 + t216 = inv(t211 * t13 + (t139 + t15) * p) + t217 = t32 * t216 + t218 = -8 + t184 + t220 = 12 + t186 + t222 = -6 + 6n + t229 = inv(t218 * t13 + (t220 * t13 + (t222 * t13 + t15) * p) * p) + t230 = t11 * t229 + t231 = t20 * t216 + t232 = t69 * n + t233 = t30 * t216 + t234 = 18 + t175 + t236 = -12 + 8n + t241 = inv(t234 * t13 + (t236 * t13 + 2t15) * p) + t242 = t38 * t241 + t243 = 3t38 * t207 - 36t50 + 12t51 - 12t54 - 9t217 + 36t55 + 12t63 - 48t230 - 52t231 - 13t232 - 14t233 + 69t242 + t245 = t32 * t192 + t251 = inv(t234 * q + (t236 * q + 2t5) * p) + t256 = inv(1 + t155 + (4n - 2 + p) * p) + t257 = t20 * t256 + t258 = 32t245 - 2t70 - 10t74 - 6t81 - 22t91 - 4t92 + 60t96 + 16t97 - 6t104 - 87t32 * t251 - 2t105 + 4t257 + t267 = inv(t218 * q + (t220 * q + (t222 * q + t5) * p) * p) + t268 = t11 * t267 + t269 = t11 * t79 + t270 = t30 * t229 + t271 = t32 * t267 + t272 = 6t69 - 64t268 - 18t113 + 4t117 - 20t118 - t269 + 32t270 + 2t119 + 4t120 + 24t122 - 2t123 + 16t271 + t276 = t32 * t27 + t277 = t69 * t9 + t278 = t38 * t116 + t279 = t38 * t192 + t281 = 77t11 * t241 - t276 + 88t133 - 28t135 - 52t136 + 16t137 + 9t277 + 35t278 - 28t138 - 48t279 + 40t143 + 155t38 * t251 + t286 = inv(t211 * q + (t108 + t5) * p) + t287 = t20 * t286 + t288 = t2 * t192 + t292 = inv(9 + t210 + (4n - 6 + p) * p) + t293 = t2 * t292 + t294 = t2 * t286 + t295 = t20 * t267 + t296 = t2 * t132 + t297 = t32 * t89 + t299 = 24t144 - 36t145 - 96t149 - 65t287 + 6t151 - 8t288 + 9t293 + 9t294 + 96t295 - 4t296 + 4t297 - 4t30 * t286 + t304 = t11 * t286 + t305 = t32 * t116 + t308 = t38 * t267 + t309 = t11 * t36 + t311 = t38 * t79 + t315 = inv(2 + t164 + (-4 + 8n + 2p) * p) + t317 = 2t11 * t292 - t32 * t207 - 2t11 * t256 + 26t304 - 25t305 + 4t30 * t198 + 16t30 * t267 - 64t308 + 11t309 - 8t76 * t229 + t311 - 5t38 * t315 + t319 = t32 * t22 + t320 = t20 * t198 + t321 = t20 * t292 + t322 = t38 * t229 + t323 = t38 * t27 + t324 = t20 * t229 + t328 = t38 * t36 + t329 = t38 * t172 + t330 = t32 * t315 + t319 + t320 - 12t321 - 8t322 + t323 + 32t324 - 2t76 * t161 + 2t76 * t216 + 53t38 * t179 + 19t328 + t329 + an2 = t201 + t243 + t258 + t272 + t281 + t299 + t317 + t330 + # an4 + t521 = 16t8 - 8t28 + t41 - 3t49 + 20t74 + 65t91 + 4t92 - 48t96 - 16t97 + 4t105 + 72t113 - 4t117 + 24t118 + + 6t269 - 4t119 - 32t120 - 64t122 - t123 - t276 - 32t133 + t526 = t2 * t73 + t527 = t2 * t36 + t528 = 48t149 - 18t151 + 8t296 - 8t297 + 51t305 - 26t309 + 5t323 + t32 * t17 + 87t32 * t141 + 4t526 - 9t527 + an4 = t521 + 32t135 + 32t136 - 8t137 - t2 * t39 - 53t278 + 8t138 - 155t142 - 32t143 - 48t144 + 48t145 + t528 + + # bn1, bn2, bn4 + t544 = t9 * F + t546 = inv(p + 2n) + t548 = q * n + t550 = inv(t5 + 2t548) + t551 = t544 * t550 + t552 = t544 * t7 + t553 = n * F + t554 = t553 * t7 + t555 = t19 * F + t557 = F * t62 + t559 = t557 * n + bn1 = 1 - F + 2t544 * t546 - 2t551 - 4t552 + 2t554 + 2t555 * t7 - 2t557 - 2t557 * t9 + 4t559 - 2t555 * t550 + 2t553 * t52 + t563 = t553 * t550 + t564 = t553 * t132 + t567 = t544 * t132 + t568 = F * t47 + t572 = inv(4 * t9 + (4n + p) * p) + t574 = q * t9 + t578 = inv(4 * t574 + (4 * t548 + t5) * p) + t580 = t544 * t578 + t582 = t553 * t47 + bn2 = -t563 - 2t564 + 2t544 * t47 - 2t555 * t132 + 4t567 + 2t568 - 2t544 * t572 + 2t555 * t578 - t551 + 2t580 + t552 - t554 + t557 - t559 + t553 * t546 - 4t582 + bn4 = -F * t52 - 2t552 + 4t554 - 2(F * t7) + 2t551 + return an1, an2, an4, bn1, bn2, bn4 +end + +function _ibeta_grad_splus(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat} + if x <= zero(T) + return zero(T), zero(T), zero(T), zero(T) + elseif x >= one(T) + return one(T), zero(T), zero(T), zero(T) + end + # ∂I/∂x at original params + dI_dx = exp(muladd(a - one(T), log(x), muladd(b - one(T), log1p(-x), -logbeta(a, b)))) + # psi + lbet = logbeta(a, b) + pa = digamma(a); pa1 = trigamma(a) + pb = digamma(b); pb1 = trigamma(b) + pab = digamma(a + b); pab1 = trigamma(a + b) + # possibly swap + x1 = x; omx = one(T) - x; pp = a; qq = b + swapped = false + if x > a / (a + b) + swapped = true + x1 = one(T) - x + omx = x + pp, qq = b, a + pa, pb = pb, pa + pa1, pb1 = pb1, pa1 + end + w = x1 / omx + logx1 = log(x1); logomx = log(omx) + cc1 = muladd(pp, logx1, muladd(qq - one(T), logomx, -lbet - log(pp))) + c0 = exp(cc1) + cc2 = logx1 - inv(pp) - pa + pab + cc4 = logomx - pb + pab + # init recurrences + an1_1 = one(T); an1_p = zero(T); an1_q = zero(T) + an2_1 = one(T); an2_p = zero(T); an2_q = zero(T) + bn1_1 = one(T); bn1_p = zero(T); bn1_q = zero(T) + bn2_1 = zero(T); bn2_p = zero(T); bn2_q = zero(T) + I = zero(T); Ip = zero(T); Iq = zero(T) + prevI = T(NaN); prevIp = T(NaN); prevIq = T(NaN) + d = one(T); n = 0 + while (n < minapp) || ((d >= err) && (n < maxapp)) + n += 1 + a1, ap, aq, b1, bp, bq = _derconf_coeffs(n, pp, qq, w) + # forward recurrences + dan1 = a1 * an2_1 + b1 * an1_1 + dbn1 = a1 * bn2_1 + b1 * bn1_1 + danp = ap * an2_1 + a1 * an2_p + bp * an1_1 + b1 * an1_p + dbnp = ap * bn2_1 + a1 * bn2_p + bp * bn1_1 + b1 * bn1_p + danq = aq * an2_1 + a1 * an2_q + bq * an1_1 + b1 * an1_q + dbnq = aq * bn2_1 + a1 * bn2_q + bq * bn1_1 + b1 * bn1_q + # scale + Rn = dan1 + if abs(dbn1) > abs(dan1) + Rn = dbn1 + end + if Rn != 0 + invRn = inv(Rn) + an1_1 *= invRn; an1_p *= invRn; an1_q *= invRn + bn1_1 *= invRn; bn1_p *= invRn; bn1_q *= invRn + danp *= invRn; dbnp *= invRn; danq *= invRn; dbnq *= invRn + if abs(dbn1) > abs(dan1) + dan1 *= invRn; dbn1 = one(T) + else + dbn1 *= invRn; dan1 = one(T) + end + else + dbn1 = one(T); dan1 = one(T) + end + # approximant components + dr1 = dan1 / dbn1 + drp = (danp - dr1 * dbnp) / dbn1 + drq = (danq - dr1 * dbnq) / dbn1 + # shift n-1/n-2 + an2_1, an2_p, an2_q = an1_1, an1_p, an1_q + an1_1, an1_p, an1_q = dan1, danp, danq + bn2_1, bn2_p, bn2_q = bn1_1, bn1_p, bn1_q + bn1_1, bn1_p, bn1_q = dbn1, dbnp, dbnq + # nth approximant + pr = dr1 > 0 ? exp(cc1 + log(dr1)) : zero(T) + I = pr + Ip = muladd(pr, cc2, c0 * drp) + Iq = muladd(pr, cc4, c0 * drq) + # convergence + d1 = max(err, abs(I)); d2 = max(err, abs(Ip)); d4 = max(err, abs(Iq)) + dI = isfinite(prevI) ? abs(prevI - I) / d1 : one(T) + dP = isfinite(prevIp) ? abs(prevIp - Ip) / d2 : one(T) + dQ = isfinite(prevIq) ? abs(prevIq - Iq) / d4 : one(T) + d = max(dI, max(dP, dQ)) + prevI, prevIp, prevIq = I, Ip, Iq + end + if swapped + I = one(T) - I + Ip, Iq = -Iq, -Ip + end + return I, Ip, Iq, dI_dx +end + + + +# Incomplete beta: beta_inc(a,b,x) -> (p, q) with q=1-p +function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, b::Number, x::Number) + # primal + p, q = beta_inc(a, b, x) + # derivatives + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) + _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * T(Δx) + Δq = -Δp + Tout = typeof((p, q)) + return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) +end + +function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number) + p, q = beta_inc(a, b, x) + Ta = ChainRulesCore.ProjectTo(a) + Tb = ChainRulesCore.ProjectTo(b) + Tx = ChainRulesCore.ProjectTo(x) + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) + _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + function beta_inc_pullback(Δ) + Δp, Δq = Δ + s = T(Δp) - T(Δq) # because q = 1 - p + ā = Ta(s * dIa) + b̄ = Tb(s * dIb) + x̄ = Tx(s * dIx) + return ChainRulesCore.NoTangent(), ā, b̄, x̄ + end + return (p, q), beta_inc_pullback +end +function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number) + p, q = beta_inc(a, b, x, y) + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) + _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * (T(Δx) - T(Δy)) + Δq = -Δp + Tout = typeof((p, q)) + return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) +end + +function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number) + p, q = beta_inc(a, b, x, y) + Ta = ChainRulesCore.ProjectTo(a) + Tb = ChainRulesCore.ProjectTo(b) + Tx = ChainRulesCore.ProjectTo(x) + Ty = ChainRulesCore.ProjectTo(y) + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) + _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + function beta_inc_pullback(Δ) + Δp, Δq = Δ + s = T(Δp) - T(Δq) + ā = Ta(s * dIa) + b̄ = Tb(s * dIb) + x̄ = Tx(s * dIx) + ȳ = Ty(-s * dIx) + return ChainRulesCore.NoTangent(), ā, b̄, x̄, ȳ + end + return (p, q), beta_inc_pullback +end + +# Inverse incomplete beta: beta_inc_inv(a,b,p) -> (x, 1-x) +function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Number, b::Number, p::Number) + x, y = beta_inc_inv(a, b, p) + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) + # Implicit differentiation at solved x: I_x(a,b) = p + _, dIa, dIb, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + # ∂I/∂x at solved x via stable log-space expression + dIx_acc = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) + inv_dIx = inv(dIx_acc) + dx_da = -dIa * inv_dIx + dx_db = -dIb * inv_dIx + dx_dp = inv_dIx + Δx = dx_da * T(Δa) + dx_db * T(Δb) + dx_dp * T(Δp) + Δy = -Δx + Tout = typeof((x, y)) + return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy) +end + +function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::Number) + x, y = beta_inc_inv(a, b, p) + Ta = ChainRulesCore.ProjectTo(a) + Tb = ChainRulesCore.ProjectTo(b) + Tp = ChainRulesCore.ProjectTo(p) + T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) + _, dIa, dIb, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + # ∂I/∂x at solved x via stable log-space expression + dIx_acc = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) + inv_dIx = inv(dIx_acc) + dx_da = -dIa * inv_dIx + dx_db = -dIb * inv_dIx + dx_dp = inv_dIx + function beta_inc_inv_pullback(Δ) + Δx, Δy = Δ + s = T(Δx) - T(Δy) + ā = Ta(s * dx_da) + b̄ = Tb(s * dx_db) + p̄ = Tp(s * dx_dp) + return ChainRulesCore.NoTangent(), ā, b̄, p̄ + end + return (x, y), beta_inc_inv_pullback +end + end # module diff --git a/test/chainrules.jl b/test/chainrules.jl index 1754d591..4b3006fd 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -176,4 +176,129 @@ _, x̄ = back(1f0) @test x̄ isa Float32 end -end + + @testset "beta_inc and beta_inc_inv" begin + @testset "beta_inc and beta_inc_inv minimal (no-FD identities)" begin + a = 1.2 + b = 2.3 + x = 0.4 + # Direct derivative checks without FD: ∂I/∂x equals beta pdf + pdf = x^(a - 1) * (1 - x)^(b - 1) / beta(a, b) + _, Δx = frule((NoTangent(), 0.0, 0.0, 1.0), beta_inc, a, b, x) + @test isapprox(Δx[1], pdf; rtol=1e-12, atol=1e-12) + + # Symmetry check: ∂I/∂a(a,b,x) = -∂I/∂b(b,a,1-x) + _, Δa = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc, a, b, x) + _, Δb_sw = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc, b, a, 1 - x) + @test isapprox(Δa[1], -Δb_sw[1]; rtol=1e-10, atol=1e-12) + + # Composition identity f(g(p)) = p: forward-mode differential equals 1 for dp, 0 for da,db + p = first(beta_inc(a, b, x)) + x_inv, _ = beta_inc_inv(a, b, p) + # Check primal composition + p_roundtrip = first(beta_inc(a, b, x_inv)) + @test isapprox(p_roundtrip, p; rtol=1e-12, atol=1e-12) + # Forward through g then f: dp + _, Δx_inv_dp = frule((NoTangent(), 0.0, 0.0, 1.0), beta_inc_inv, a, b, p) + _, Δp_from_dp = frule((NoTangent(), 0.0, 0.0, Δx_inv_dp[1]), beta_inc, a, b, x_inv) + @test isapprox(Δp_from_dp[1], 1.0; rtol=1e-9, atol=1e-12) + # Forward da + _, Δx_inv_da = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc_inv, a, b, p) + _, Δp_from_da = frule((NoTangent(), 1.0, 0.0, Δx_inv_da[1]), beta_inc, a, b, x_inv) + @test isapprox(Δp_from_da[1], 0.0; rtol=1e-9, atol=1e-12) + # Forward db + _, Δx_inv_db = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc_inv, a, b, p) + _, Δp_from_db = frule((NoTangent(), 0.0, 1.0, Δx_inv_db[1]), beta_inc, a, b, x_inv) + @test isapprox(Δp_from_db[1], 0.0; rtol=1e-9, atol=1e-12) + + # Reverse-mode chain for composition: pullback through f then g + # Pullback of f at (a,b,x_inv) + _, pb_f = rrule(beta_inc, a, b, x_inv) + _, āf, b̄f, x̄f = pb_f((1.0, 0.0)) + # Pullback of g at (a,b,p) with cotangent x̄f for x + _, pb_g = rrule(beta_inc_inv, a, b, p) + _, āg, b̄g, p̄g = pb_g((x̄f, 0.0)) + ā_total = āf + āg + b̄_total = b̄f + b̄g + p̄_total = p̄g + @test isapprox(ā_total, 0.0; rtol=1e-10, atol=1e-12) + @test isapprox(b̄_total, 0.0; rtol=1e-10, atol=1e-12) + @test isapprox(p̄_total, 1.0; rtol=1e-9, atol=1e-12) + end + + @testset "incomplete beta: basic test_frule/test_rrule" begin + # Use a small, representative set of interior points (avoid endpoints for FD) + test_points = (0.2, 0.5, 0.8) + ab = (0.7, 2.5) + + # 3-argument beta_inc(a,b,x) + for a in ab, b in ab, x in test_points + 0.0 < x < 1.0 || continue + test_frule(beta_inc, a, b, x) + test_rrule(beta_inc, a, b, x) + end + + # Inverse beta: beta_inc_inv(a,b,p) + for a in ab, b in ab, p in test_points + 0.0 < p < 1.0 || continue + test_frule(beta_inc_inv, a, b, p) + test_rrule(beta_inc_inv, a, b, p) + end + + # Float32 promotion sanity (lightweight) + a32 = 1.5f0; b32 = 2.25f0; x32 = 0.3f0 + # Finite-difference checks for Float32 are noisier; use looser tolerances + test_frule(beta_inc, a32, b32, x32; rtol=5e-4, atol=1e-6) + test_rrule(beta_inc, a32, b32, x32; rtol=5e-4, atol=1e-6) + p32 = first(beta_inc(a32, b32, x32)) + test_frule(beta_inc_inv, a32, b32, p32; rtol=5e-4, atol=1e-6) + test_rrule(beta_inc_inv, a32, b32, p32; rtol=5e-4, atol=1e-6) + end + + @testset "4-arg beta_inc identities (y = 1 - x)" begin + test_points = (0.2, 0.5, 0.8) + ab = (0.7, 2.5) + + for a in ab, b in ab, x in test_points + 0.0 < x < 1.0 || continue + y = 1 - x + # Primal consistency: 4-arg matches 3-arg when y = 1 - x + p3, q3 = beta_inc(a, b, x) + p4, q4 = beta_inc(a, b, x, y) + @test isapprox(p4, p3; rtol=1e-12, atol=1e-12) + @test isapprox(q4, q3; rtol=1e-12, atol=1e-12) + + # Analytical pdf + pdf = x^(a - 1) * (1 - x)^(b - 1) / beta(a, b) + + # Constrained x-variation: dx = 1, dy = -1 => dp = 2 * pdf, dq = -dp + _, Δxy = frule((NoTangent(), 0.0, 0.0, 1.0, -1.0), beta_inc, a, b, x, y) + @test isapprox(Δxy[1], 2 * pdf; rtol=1e-11, atol=1e-12) + @test isapprox(Δxy[2], -Δxy[1]; rtol=1e-11, atol=1e-12) + + # Parameter derivatives should match 3-arg ones + _, Δa3 = frule((NoTangent(), 1.0, 0.0, 0.0), beta_inc, a, b, x) + _, Δb3 = frule((NoTangent(), 0.0, 1.0, 0.0), beta_inc, a, b, x) + _, Δa4 = frule((NoTangent(), 1.0, 0.0, 0.0, 0.0), beta_inc, a, b, x, y) + _, Δb4 = frule((NoTangent(), 0.0, 1.0, 0.0, 0.0), beta_inc, a, b, x, y) + @test isapprox(Δa4[1], Δa3[1]; rtol=1e-11, atol=1e-12) + @test isapprox(Δb4[1], Δb3[1]; rtol=1e-11, atol=1e-12) + + # Reverse-mode: compare pullbacks for 3-arg vs constrained 4-arg + _, pb3 = rrule(beta_inc, a, b, x) + _, ā3, b̄3, x̄3 = pb3((1.0, 0.0)) + _, pb4 = rrule(beta_inc, a, b, x, y) + _, ā4, b̄4, x̄4, ȳ4 = pb4((1.0, 0.0)) + @test isapprox(ā4, ā3; rtol=1e-11, atol=1e-12) + @test isapprox(b̄4, b̄3; rtol=1e-11, atol=1e-12) + # Unconstrained pullbacks should satisfy x̄4 ≈ x̄3 and ȳ4 ≈ -x̄3 + @test isapprox(x̄4, x̄3; rtol=1e-11, atol=1e-12) + @test isapprox(ȳ4, -x̄3; rtol=1e-11, atol=1e-12) + # Effective pullback along the constraint y = 1 - x equals 2*x̄3 + x̄_eff = x̄4 - ȳ4 + @test isapprox(x̄_eff, 2 * x̄3; rtol=1e-11, atol=1e-12) + end + end + + end +end \ No newline at end of file From 5a4fedd60867430ee90fc4568eda1773a7153e4c Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Sat, 4 Oct 2025 22:27:13 +0200 Subject: [PATCH 2/6] Try to fix on older julia --- ext/SpecialFunctionsChainRulesCoreExt.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl index 0ce0c2ec..289b859c 100644 --- a/ext/SpecialFunctionsChainRulesCoreExt.jl +++ b/ext/SpecialFunctionsChainRulesCoreExt.jl @@ -636,7 +636,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, # derivatives T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) - Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * T(Δx) + Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * convert(T, Δx) Δq = -Δp Tout = typeof((p, q)) return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) @@ -651,7 +651,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) function beta_inc_pullback(Δ) Δp, Δq = Δ - s = T(Δp) - T(Δq) # because q = 1 - p + s = Δp - Δq # because q = 1 - p ā = Ta(s * dIa) b̄ = Tb(s * dIb) x̄ = Tx(s * dIx) @@ -663,7 +663,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Nu p, q = beta_inc(a, b, x, y) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) - Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * (T(Δx) - T(Δy)) + Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * (convert(T, Δx) - convert(T, Δy)) Δq = -Δp Tout = typeof((p, q)) return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) @@ -679,7 +679,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) function beta_inc_pullback(Δ) Δp, Δq = Δ - s = T(Δp) - T(Δq) + s = Δp - Δq ā = Ta(s * dIa) b̄ = Tb(s * dIb) x̄ = Tx(s * dIx) @@ -701,7 +701,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num dx_da = -dIa * inv_dIx dx_db = -dIb * inv_dIx dx_dp = inv_dIx - Δx = dx_da * T(Δa) + dx_db * T(Δb) + dx_dp * T(Δp) + Δx = dx_da * convert(T, Δa) + dx_db * convert(T, Δb) + dx_dp * convert(T, Δp) Δy = -Δx Tout = typeof((x, y)) return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy) @@ -722,7 +722,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::N dx_dp = inv_dIx function beta_inc_inv_pullback(Δ) Δx, Δy = Δ - s = T(Δx) - T(Δy) + s = Δx - Δy ā = Ta(s * dx_da) b̄ = Tb(s * dx_db) p̄ = Tp(s * dx_dp) From c96961d092509f1c12d879b3f425aa749a33f835 Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Sun, 5 Oct 2025 14:47:51 +0200 Subject: [PATCH 3/6] Change implementaiton to arzwa/IncBetaDer.jl --- ext/SpecialFunctionsChainRulesCoreExt.jl | 534 +++++++++-------------- 1 file changed, 200 insertions(+), 334 deletions(-) diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl index 289b859c..e1b75eba 100644 --- a/ext/SpecialFunctionsChainRulesCoreExt.jl +++ b/ext/SpecialFunctionsChainRulesCoreExt.jl @@ -301,330 +301,180 @@ function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number) end -# Note on incomplete beta derivatives implementation -# -------------------------------------------------- -# The rules for the regularized incomplete beta I_x(a,b) and its inverse are -# implemented using a direct translation of the original S-PLUS/MATLAB code by -# Boik & Robinson-Cox. See: -# Boik, R. J., & Robinson-Cox, J. F. (1998). -# Derivatives of the incomplete beta function with respect to its parameters. -# Computational Statistics & Data Analysis, 27(1), 85–106. -# The coefficient recurrences and derivative accumulation are ported verbatim -# (scalar form) from inc.beta.deriv.S/inbeder.m. - -@inline function _derconf_coeffs(n::Int, p::T, q::T, w::T) where {T<:AbstractFloat} - F = w * q / p - if n == 1 - t1 = 1 - inv(p + 1) - t2 = 1 - inv(q) - t3 = 1 - 2 / (p + 2) - t4 = 1 - 2 / q - an1 = t1 * t2 * F - an2 = -an1 / (p + 1) - an4 = t1 * F / q - bn1 = 1 - t3 * t4 * F - bn2 = t3 * t4 * F / (p + 2) - bn4 = -t3 * F / q - return an1, an2, an4, bn1, bn2, bn4 - end - t2 = F^2 - t3 = 2n - 2 - t5 = p * q - t7 = inv(t3 * q + t5) - t8 = t2 * t7 - t9 = n^2 - t10 = t9^2 - t11 = t2 * t10 - t12 = 4n - 2 - t13 = q^2 - t14 = t12 * t13 - t15 = p * t13 - t17 = inv(t14 + 2t15) - t19 = t9 * n - t20 = t19 * t2 - t22 = inv(p + (2n - 1)) - t23 = t20 * t22 - t24 = 2n - 1 - t27 = inv(t24 * q + t5) - t28 = t20 * t27 - t30 = t10 * n * t2 - t32 = n * t2 - t33 = 2n - 3 - t36 = inv(t33 * t13 + t15) - t37 = t32 * t36 - t38 = t9 * t2 - t39 = inv(t13) - t41 = t32 * t39 - t43 = (-8 + 4n) * n - t47 = inv(4 + t43 + (4n - 4 + p) * p) - t49 = t38 * t17 - t50 = t38 * t47 - t51 = t20 * t47 - t52 = inv(q) - t54 = t2 * t47 - t55 = t32 * t47 - t57 = inv(2p + (4n - 6)) - t59 = 4t8 - 3t11 * t17 - 4t23 - t28 - 4t30 * t27 + 9t37 - t38 * t39 + t41 + 4t11 * t47 - t49 + 24t50 - 16t51 - t2 * t52 + 4t54 - 16t55 - 53t38 * t57 - t62 = inv(p + (2n - 2)) - t63 = t32 * t62 - t65 = inv(2p + (4n - 2)) - t69 = t2 * inv(p + (2n - 3)) - t70 = t69 * t19 - t73 = inv(t3 * t13 + t15) - t74 = t11 * t73 - t76 = t10 * t9 * t2 - t79 = inv(t24 * t13 + t15) - t81 = t2 * t62 - t82 = 4 + t43 - t84 = 4n - 4 - t89 = inv(t82 * t13 + (t84 * t13 + t15) * p) - t91 = t20 * t36 - t92 = t11 * t27 - t96 = t20 * t89 - t97 = t20 * t7 - t98 = t12 * q - t100 = inv(t98 + 2t5) - t102 = 51t32 * t57 - 24t63 + 5t38 * t65 + 12t70 + 40t74 + 2t76 * t79 + 8t81 + 4t76 * t89 + 52t91 + 6t92 - 2t69 * t10 - 8t20 * t62 + 2t11 * t22 - 16t96 - 64t97 + t32 * t100 - t104 = t38 * t62 - t105 = t30 * t36 - t107 = 4n - 6 - t108 = t107 * q - t110 = inv(t108 + 2t5) - t113 = t38 * t73 - t116 = inv(t33 * q + t5) - t117 = t11 * t116 - t118 = t20 * t116 - t119 = t30 * t79 - t120 = t32 * t73 - t122 = t20 * t73 - t123 = t20 * t79 - t126 = 24t104 + 14t105 + t32 * t52 + 87t32 * t110 - 9t69 - 12t30 * t73 + 24t113 - 26t117 + 65t118 - 2t119 - 4t120 + 4t30 * t116 - 48t122 + 2t123 - 2t76 * t36 - 3t38 * t100 - t132 = inv(t82 * q + (t84 * q + t5) * p) - t133 = t20 * t132 - t135 = t38 * t89 - t136 = t11 * t89 - t137 = t30 * t89 - t138 = t11 * t132 - t139 = t107 * t13 - t141 = inv(t139 + 2t15) - t142 = t38 * t141 - t143 = t32 * t132 - t144 = t32 * t7 - t145 = t38 * t7 - t149 = t38 * t132 - t151 = t2 * t116 - t152 = -48t133 - 8t30 * t132 + 4t135 + 24t136 - 16t137 + 32t138 - 69t142 - 8t143 - 32t144 + 72t145 - t32 * t65 + 20t11 * t7 - 77t11 * t141 + 32t149 - 155t38 * t110 - 9t151 - an1 = t59 + t102 + t126 + t152 - # an2 (∂/∂p) - t155 = (4n - 4) * n - t156 = 1 + t155 - t161 = inv(t156 * t13 + (t14 + t15) * p) - t162 = t30 * t161 - t163 = -8 + 8n - t164 = t163 * n - t165 = 2 + t164 - t167 = -4 + 8n - t172 = inv(t165 * t13 + (t167 * t13 + 2t15) * p) - t175 = (-24 + 8n) * n - t179 = inv(18 + t175 + (-12 + 8n + 2p) * p) - t181 = t20 * t161 - t182 = t38 * t22 - t184 = (24 + t175) * n - t186 = (-24 + 12n) * n - t192 = inv(-8 + t184 + (12 + t186 + (-6 + 6n + p) * p) * p) - t198 = inv(t156 * q + (t98 + t5) * p) - t199 = t11 * t198 - t200 = t20 * t192 - t201 = -4t8 + 2t162 + 3t11 * t172 - 51t32 * t179 + 2t23 + 4t28 - 2t181 - 3t182 - 8t11 * t192 - 6t199 + 32t200 - 6t37 - t207 = inv(t165 * q + (t167 * q + 2t5) * p) - t210 = (-12 + 4n) * n - t211 = 9 + t210 - t216 = inv(t211 * t13 + (t139 + t15) * p) - t217 = t32 * t216 - t218 = -8 + t184 - t220 = 12 + t186 - t222 = -6 + 6n - t229 = inv(t218 * t13 + (t220 * t13 + (t222 * t13 + t15) * p) * p) - t230 = t11 * t229 - t231 = t20 * t216 - t232 = t69 * n - t233 = t30 * t216 - t234 = 18 + t175 - t236 = -12 + 8n - t241 = inv(t234 * t13 + (t236 * t13 + 2t15) * p) - t242 = t38 * t241 - t243 = 3t38 * t207 - 36t50 + 12t51 - 12t54 - 9t217 + 36t55 + 12t63 - 48t230 - 52t231 - 13t232 - 14t233 + 69t242 - t245 = t32 * t192 - t251 = inv(t234 * q + (t236 * q + 2t5) * p) - t256 = inv(1 + t155 + (4n - 2 + p) * p) - t257 = t20 * t256 - t258 = 32t245 - 2t70 - 10t74 - 6t81 - 22t91 - 4t92 + 60t96 + 16t97 - 6t104 - 87t32 * t251 - 2t105 + 4t257 - t267 = inv(t218 * q + (t220 * q + (t222 * q + t5) * p) * p) - t268 = t11 * t267 - t269 = t11 * t79 - t270 = t30 * t229 - t271 = t32 * t267 - t272 = 6t69 - 64t268 - 18t113 + 4t117 - 20t118 - t269 + 32t270 + 2t119 + 4t120 + 24t122 - 2t123 + 16t271 - t276 = t32 * t27 - t277 = t69 * t9 - t278 = t38 * t116 - t279 = t38 * t192 - t281 = 77t11 * t241 - t276 + 88t133 - 28t135 - 52t136 + 16t137 + 9t277 + 35t278 - 28t138 - 48t279 + 40t143 + 155t38 * t251 - t286 = inv(t211 * q + (t108 + t5) * p) - t287 = t20 * t286 - t288 = t2 * t192 - t292 = inv(9 + t210 + (4n - 6 + p) * p) - t293 = t2 * t292 - t294 = t2 * t286 - t295 = t20 * t267 - t296 = t2 * t132 - t297 = t32 * t89 - t299 = 24t144 - 36t145 - 96t149 - 65t287 + 6t151 - 8t288 + 9t293 + 9t294 + 96t295 - 4t296 + 4t297 - 4t30 * t286 - t304 = t11 * t286 - t305 = t32 * t116 - t308 = t38 * t267 - t309 = t11 * t36 - t311 = t38 * t79 - t315 = inv(2 + t164 + (-4 + 8n + 2p) * p) - t317 = 2t11 * t292 - t32 * t207 - 2t11 * t256 + 26t304 - 25t305 + 4t30 * t198 + 16t30 * t267 - 64t308 + 11t309 - 8t76 * t229 + t311 - 5t38 * t315 - t319 = t32 * t22 - t320 = t20 * t198 - t321 = t20 * t292 - t322 = t38 * t229 - t323 = t38 * t27 - t324 = t20 * t229 - t328 = t38 * t36 - t329 = t38 * t172 - t330 = t32 * t315 + t319 + t320 - 12t321 - 8t322 + t323 + 32t324 - 2t76 * t161 + 2t76 * t216 + 53t38 * t179 + 19t328 + t329 - an2 = t201 + t243 + t258 + t272 + t281 + t299 + t317 + t330 - # an4 - t521 = 16t8 - 8t28 + t41 - 3t49 + 20t74 + 65t91 + 4t92 - 48t96 - 16t97 + 4t105 + 72t113 - 4t117 + 24t118 + - 6t269 - 4t119 - 32t120 - 64t122 - t123 - t276 - 32t133 - t526 = t2 * t73 - t527 = t2 * t36 - t528 = 48t149 - 18t151 + 8t296 - 8t297 + 51t305 - 26t309 + 5t323 + t32 * t17 + 87t32 * t141 + 4t526 - 9t527 - an4 = t521 + 32t135 + 32t136 - 8t137 - t2 * t39 - 53t278 + 8t138 - 155t142 - 32t143 - 48t144 + 48t145 + t528 - - # bn1, bn2, bn4 - t544 = t9 * F - t546 = inv(p + 2n) - t548 = q * n - t550 = inv(t5 + 2t548) - t551 = t544 * t550 - t552 = t544 * t7 - t553 = n * F - t554 = t553 * t7 - t555 = t19 * F - t557 = F * t62 - t559 = t557 * n - bn1 = 1 - F + 2t544 * t546 - 2t551 - 4t552 + 2t554 + 2t555 * t7 - 2t557 - 2t557 * t9 + 4t559 - 2t555 * t550 + 2t553 * t52 - t563 = t553 * t550 - t564 = t553 * t132 - t567 = t544 * t132 - t568 = F * t47 - t572 = inv(4 * t9 + (4n + p) * p) - t574 = q * t9 - t578 = inv(4 * t574 + (4 * t548 + t5) * p) - t580 = t544 * t578 - t582 = t553 * t47 - bn2 = -t563 - 2t564 + 2t544 * t47 - 2t555 * t132 + 4t567 + 2t568 - 2t544 * t572 + 2t555 * t578 - t551 + 2t580 + t552 - t554 + t557 - t559 + t553 * t546 - 4t582 - bn4 = -F * t52 - 2t552 + 4t554 - 2(F * t7) + 2t551 - return an1, an2, an4, bn1, bn2, bn4 -end -function _ibeta_grad_splus(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat} - if x <= zero(T) - return zero(T), zero(T), zero(T), zero(T) - elseif x >= one(T) - return one(T), zero(T), zero(T), zero(T) +## Incomplete beta derivatives via Boik & Robinson-Cox +# +# Reference +# R. J. Boik and J. F. Robinson-Cox (1999). +# "Derivatives of the incomplete beta function." +# Journal of Statistical Software, 3(1). +# URL: https://www.jstatsoft.org/article/view/v003i01 +# +# The following implementation computes the regularized incomplete beta +# I_x(a,b) together with its partial derivatives with respect to a, b, and x +# using a continued-fraction representation of ₂F₁ and differentiating through it. +# This is an independent implementation adapted from https://github.com/arzwa/IncBetaDer.jl. + +# Generic-typed version for high-precision evaluation +function _beta_inc_grad_boik(a::T, b::T, x::T, + maxapp::Int=200, minapp::Int=3, ϵ::T=convert(T, 1e-12)) where {T<:AbstractFloat} + oneT = one(T); zeroT = zero(T) + if x == oneT + return oneT, zeroT, zeroT, zeroT + elseif x == zeroT + return zeroT, zeroT, zeroT, zeroT end - # ∂I/∂x at original params - dI_dx = exp(muladd(a - one(T), log(x), muladd(b - one(T), log1p(-x), -logbeta(a, b)))) - # psi - lbet = logbeta(a, b) - pa = digamma(a); pa1 = trigamma(a) - pb = digamma(b); pb1 = trigamma(b) - pab = digamma(a + b); pab1 = trigamma(a + b) - # possibly swap - x1 = x; omx = one(T) - x; pp = a; qq = b - swapped = false + dx = exp((a - oneT) * log(x) + (b - oneT) * log1p(-x) - logbeta(a,b)) + # swap tails if necessary + p = a; q = b; x₀ = x; swap = false if x > a / (a + b) - swapped = true - x1 = one(T) - x - omx = x - pp, qq = b, a - pa, pb = pb, pa - pa1, pb1 = pb1, pa1 + x₀ = oneT - x + p = b + q = a + swap = true + end + Kfun(x::T, p::T, q::T) = exp(p * log(x) + (q - oneT) * log1p(-x) - log(p) - logbeta(p, q)) + ffun(x::T, p::T, q::T) = q*x/(p*(oneT - x)) + a1fun(p::T, q::T, f::T) = p*f*(q - oneT)/(q*(p + oneT)) + anfun(p::T, q::T, f::T, n::Int) = n == 1 ? a1fun(p, q, f) : + p^2 * f^2 * (T(n) - oneT) * (p + q + T(n) - T(2)) * (p + T(n) - oneT) * (q - T(n)) / + (q^2 * (p + T(2n) - T(3)) * (p + T(2n) - T(2))^2 * (p + T(2n) - oneT)) + function bnfun(p::T, q::T, f::T, n::Int) + x = T(2)*(p*f + T(2)*q)*T(n)^2 + T(2)*(p*f + T(2)*q)*(p - oneT)*T(n) + p*q*(p - T(2) - p*f) + y = (q * (p + T(2n) - T(2)) * (p + T(2n))) + x/y + end + dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) = K*(log(x) - inv(p) + ψpq - ψp) + dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) = K*(log1p(-x) + ψpq - ψq) + function dK_dpdq(x::T, p::T, q::T) + ψ = digamma(p+q) + Kf = Kfun(x, p, q) + dKdp = dK_dp(x, p, q, Kf, ψ, digamma(p)) + dKdq = dK_dq(x, p, q, Kf, ψ, digamma(q)) + dKdp, dKdq end - w = x1 / omx - logx1 = log(x1); logomx = log(omx) - cc1 = muladd(pp, logx1, muladd(qq - one(T), logomx, -lbet - log(pp))) - c0 = exp(cc1) - cc2 = logx1 - inv(pp) - pa + pab - cc4 = logomx - pb + pab - # init recurrences - an1_1 = one(T); an1_p = zero(T); an1_q = zero(T) - an2_1 = one(T); an2_p = zero(T); an2_q = zero(T) - bn1_1 = one(T); bn1_p = zero(T); bn1_q = zero(T) - bn2_1 = zero(T); bn2_p = zero(T); bn2_q = zero(T) - I = zero(T); Ip = zero(T); Iq = zero(T) - prevI = T(NaN); prevIp = T(NaN); prevIq = T(NaN) - d = one(T); n = 0 - while (n < minapp) || ((d >= err) && (n < maxapp)) - n += 1 - a1, ap, aq, b1, bp, bq = _derconf_coeffs(n, pp, qq, w) - # forward recurrences - dan1 = a1 * an2_1 + b1 * an1_1 - dbn1 = a1 * bn2_1 + b1 * bn1_1 - danp = ap * an2_1 + a1 * an2_p + bp * an1_1 + b1 * an1_p - dbnp = ap * bn2_1 + a1 * bn2_p + bp * bn1_1 + b1 * bn1_p - danq = aq * an2_1 + a1 * an2_q + bq * an1_1 + b1 * an1_q - dbnq = aq * bn2_1 + a1 * bn2_q + bq * bn1_1 + b1 * bn1_q - # scale - Rn = dan1 - if abs(dbn1) > abs(dan1) - Rn = dbn1 + # a_n derivatives via log-derivative + da1_dp(p::T, q::T, f::T) = -a1fun(p, q, f) / (p + oneT) + function dan_dp(p::T, q::T, f::T, n::Int) + if n == 1 + return da1_dp(p, q, f) end - if Rn != 0 - invRn = inv(Rn) - an1_1 *= invRn; an1_p *= invRn; an1_q *= invRn - bn1_1 *= invRn; bn1_p *= invRn; bn1_q *= invRn - danp *= invRn; dbnp *= invRn; danq *= invRn; dbnq *= invRn - if abs(dbn1) > abs(dan1) - dan1 *= invRn; dbn1 = one(T) - else - dbn1 *= invRn; dan1 = one(T) + an = anfun(p, q, f, n) + dlog = inv(p + q + T(n) - T(2)) + inv(p + T(n) - oneT) - inv(p + T(2n) - T(3)) - T(2) * inv(p + T(2n) - T(2)) - inv(p + T(2n) - oneT) + return an * dlog + end + da1_dq(p::T, q::T, f::T) = a1fun(p, q, f) / (q - oneT) + function dan_dq(p::T, q::T, f::T, n::Int) + if n == 1 + return da1_dq(p, q, f) + end + an = anfun(p, q, f, n) + dlog = inv(p + q + T(n) - T(2)) + inv(q - T(n)) + return an * dlog + end + # b_n derivatives via quotient rule, accounting for f_p=-f/p, f_q=f/q which cancel in N + function dbn_dp(p::T, q::T, f::T, n::Int) + g = p * f + T(2) * q + A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n) + N1 = g * A + N2 = p * q * (p - T(2) - p * f) + N = N1 + N2 + D = q * (p + T(2n) - T(2)) * (p + T(2n)) + dN1_dp = T(2) * T(n) * g + dN2_dp = q * (T(2) * p - T(2)) - p * q * f + dN_dp = dN1_dp + dN2_dp + dD_dp = q * (T(2) * p + T(4) * T(n) - T(2)) + return (dN_dp * D - N * dD_dp) / (D^2) + end + function dbn_dq(p::T, q::T, f::T, n::Int) + g = p * f + T(2) * q + A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n) + N1 = g * A + N2 = p * q * (p - T(2) - p * f) + N = N1 + N2 + D = q * (p + T(2n) - T(2)) * (p + T(2n)) + g_q = p * (f / q) + T(2) + dN1_dq = g_q * A + dN2_dq = p * (p - T(2) - p * f) - p^2 * f + dN_dq = dN1_dq + dN2_dq + dD_dq = (p + T(2n) - T(2)) * (p + T(2n)) + return (dN_dq * D - N * dD_dq) / (D^2) + end + _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) = begin + an = anfun(p, q, f, n) + bn = bnfun(p, q, f, n) + An = an*App + bn*Ap + Bn = an*Bpp + bn*Bp + An, Bn, an, bn + end + _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) = dan * Xpp + an * dXpp + dbn * Xp + bn * dXp + + # compute once + K = Kfun(x₀, p, q) + dK_dp_val, dK_dq_val = dK_dpdq(x₀, p, q) + f = ffun(x₀, p, q) + App = oneT; Ap = oneT; Bpp = zeroT; Bp = oneT + dApp_dp = zeroT; dBpp_dp = zeroT; dAp_dp = zeroT; dBp_dp = zeroT + dApp_dq = zeroT; dBpp_dq = zeroT; dAp_dq = zeroT; dBp_dq = zeroT + dI_dp = T(NaN); dI_dq = T(NaN); Ixpq = T(NaN); Ixpqn = T(NaN); dI_dp_prev = T(NaN); dI_dq_prev = T(NaN) + for n=1:maxapp + An, Bn, an, bn = _nextapp(f, p, q, n, App, Ap, Bpp, Bp) + dan = dan_dp(p, q, f, n); dbn = dbn_dp(p, q, f, n) + dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp) + dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp) + dan = dan_dq(p, q, f, n); dbn = dbn_dq(p, q, f, n) + dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq) + dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq) + # normalize states to control growth/underflow (scale-invariant) + s = maximum((abs(An), abs(Bn), abs(Ap), abs(Bp), abs(App), abs(Bpp))) + if isfinite(s) && s > zeroT + invs = inv(s) + An *= invs; Bn *= invs + Ap *= invs; Bp *= invs + App *= invs; Bpp *= invs + dAn_dp *= invs; dBn_dp *= invs + dAn_dq *= invs; dBn_dq *= invs + dAp_dp *= invs; dBp_dp *= invs + dApp_dp *= invs; dBpp_dp *= invs + dAp_dq *= invs; dBp_dq *= invs + dApp_dq *= invs; dBpp_dq *= invs + end + Cn = An/Bn + dI_dp = dK_dp_val * Cn + K * (inv(Bn) * dAn_dp - (An/(Bn^2)) * dBn_dp) + dI_dq = dK_dq_val * Cn + K * (inv(Bn) * dAn_dq - (An/(Bn^2)) * dBn_dq) + Ixpqn = K * Cn + if n >= minapp + denomI = max(abs(Ixpqn), abs(Ixpq), eps(T)) + denomp = max(abs(dI_dp), abs(dI_dp_prev), eps(T)) + denomq = max(abs(dI_dq), abs(dI_dq_prev), eps(T)) + rI = abs(Ixpqn - Ixpq) / denomI + rp = abs(dI_dp - dI_dp_prev) / denomp + rq = abs(dI_dq - dI_dq_prev) / denomq + if max(rI, rp, rq) < ϵ + break end - else - dbn1 = one(T); dan1 = one(T) end - # approximant components - dr1 = dan1 / dbn1 - drp = (danp - dr1 * dbnp) / dbn1 - drq = (danq - dr1 * dbnq) / dbn1 - # shift n-1/n-2 - an2_1, an2_p, an2_q = an1_1, an1_p, an1_q - an1_1, an1_p, an1_q = dan1, danp, danq - bn2_1, bn2_p, bn2_q = bn1_1, bn1_p, bn1_q - bn1_1, bn1_p, bn1_q = dbn1, dbnp, dbnq - # nth approximant - pr = dr1 > 0 ? exp(cc1 + log(dr1)) : zero(T) - I = pr - Ip = muladd(pr, cc2, c0 * drp) - Iq = muladd(pr, cc4, c0 * drq) - # convergence - d1 = max(err, abs(I)); d2 = max(err, abs(Ip)); d4 = max(err, abs(Iq)) - dI = isfinite(prevI) ? abs(prevI - I) / d1 : one(T) - dP = isfinite(prevIp) ? abs(prevIp - Ip) / d2 : one(T) - dQ = isfinite(prevIq) ? abs(prevIq - Iq) / d4 : one(T) - d = max(dI, max(dP, dQ)) - prevI, prevIp, prevIq = I, Ip, Iq + Ixpq = Ixpqn + dI_dp_prev = dI_dp + dI_dq_prev = dI_dq + App = Ap; Bpp = Bp; Ap = An; Bp = Bn + dApp_dp = dAp_dp; dApp_dq = dAp_dq; dBpp_dp = dBp_dp; dBpp_dq = dBp_dq + dAp_dp = dAn_dp; dAp_dq = dAn_dq; dBp_dp = dBn_dp; dBp_dq = dBn_dq end - if swapped - I = one(T) - I - Ip, Iq = -Iq, -Ip + if swap + return oneT - Ixpqn, -dI_dq, -dI_dp, dx + else + return Ixpqn, dI_dp, dI_dq, dx end - return I, Ip, Iq, dI_dx +end + +# Generic wrapper preserving the previous interface/name +function _ibeta_grad_splus(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat} + tol = min(err, T(1e-14)) + maxit = max(1000, maxapp) + minit = max(5, minapp) + I, dIa, dIb, dIx = _beta_inc_grad_boik(a, b, x, maxit, minit, tol) + return I, dIa, dIb, dIx end @@ -635,8 +485,12 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, p, q = beta_inc(a, b, x) # derivatives T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) - _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) - Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * convert(T, Δx) + _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ + ΔaT::T = Δa isa Real ? T(Δa) : zero(T) + ΔbT::T = Δb isa Real ? T(Δb) : zero(T) + ΔxT::T = Δx isa Real ? T(Δx) : zero(T) + Δp = dIa * ΔaT + dIb * ΔbT + dIx * ΔxT Δq = -Δp Tout = typeof((p, q)) return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) @@ -648,7 +502,8 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe Tb = ChainRulesCore.ProjectTo(b) Tx = ChainRulesCore.ProjectTo(x) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) - _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ function beta_inc_pullback(Δ) Δp, Δq = Δ s = Δp - Δq # because q = 1 - p @@ -662,8 +517,13 @@ end function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number) p, q = beta_inc(a, b, x, y) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) - _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) - Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * (convert(T, Δx) - convert(T, Δy)) + _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ + ΔaT::T = Δa isa Real ? T(Δa) : zero(T) + ΔbT::T = Δb isa Real ? T(Δb) : zero(T) + ΔxT::T = Δx isa Real ? T(Δx) : zero(T) + ΔyT::T = Δy isa Real ? T(Δy) : zero(T) + Δp = dIa * ΔaT + dIb * ΔbT + dIx * (ΔxT - ΔyT) Δq = -Δp Tout = typeof((p, q)) return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq) @@ -676,7 +536,8 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe Tx = ChainRulesCore.ProjectTo(x) Ty = ChainRulesCore.ProjectTo(y) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) - _, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ function beta_inc_pullback(Δ) Δp, Δq = Δ s = Δp - Δq @@ -694,14 +555,18 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num x, y = beta_inc_inv(a, b, p) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) # Implicit differentiation at solved x: I_x(a,b) = p - _, dIa, dIb, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_ # ∂I/∂x at solved x via stable log-space expression - dIx_acc = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) - inv_dIx = inv(dIx_acc) - dx_da = -dIa * inv_dIx - dx_db = -dIb * inv_dIx - dx_dp = inv_dIx - Δx = dx_da * convert(T, Δa) + dx_db * convert(T, Δb) + dx_dp * convert(T, Δp) + dIx_acc::T = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) + inv_dIx::T = inv(dIx_acc) + dx_da::T = -dIa * inv_dIx + dx_db::T = -dIb * inv_dIx + dx_dp::T = inv_dIx + ΔaT::T = Δa isa Real ? T(Δa) : zero(T) + ΔbT::T = Δb isa Real ? T(Δb) : zero(T) + ΔpT::T = Δp isa Real ? T(Δp) : zero(T) + Δx = dx_da * ΔaT + dx_db * ΔbT + dx_dp * ΔpT Δy = -Δx Tout = typeof((x, y)) return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy) @@ -713,13 +578,14 @@ function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::N Tb = ChainRulesCore.ProjectTo(b) Tp = ChainRulesCore.ProjectTo(p) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) - _, dIa, dIb, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + dIa::T = dIa_; dIb::T = dIb_ # ∂I/∂x at solved x via stable log-space expression - dIx_acc = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) - inv_dIx = inv(dIx_acc) - dx_da = -dIa * inv_dIx - dx_db = -dIb * inv_dIx - dx_dp = inv_dIx + dIx_acc::T = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) + inv_dIx::T = inv(dIx_acc) + dx_da::T = -dIa * inv_dIx + dx_db::T = -dIb * inv_dIx + dx_dp::T = inv_dIx function beta_inc_inv_pullback(Δ) Δx, Δy = Δ s = Δx - Δy From bf3dc6c22ff157eefa0addd34734b608ad3d46da Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Mon, 13 Oct 2025 23:50:48 +0200 Subject: [PATCH 4/6] Move subfunctions out & add comments --- ext/SpecialFunctionsChainRulesCoreExt.jl | 384 +++++++++++++++-------- 1 file changed, 248 insertions(+), 136 deletions(-) diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl index e1b75eba..8c864224 100644 --- a/ext/SpecialFunctionsChainRulesCoreExt.jl +++ b/ext/SpecialFunctionsChainRulesCoreExt.jl @@ -315,152 +315,271 @@ end # using a continued-fraction representation of ₂F₁ and differentiating through it. # This is an independent implementation adapted from https://github.com/arzwa/IncBetaDer.jl. -# Generic-typed version for high-precision evaluation -function _beta_inc_grad_boik(a::T, b::T, x::T, - maxapp::Int=200, minapp::Int=3, ϵ::T=convert(T, 1e-12)) where {T<:AbstractFloat} - oneT = one(T); zeroT = zero(T) - if x == oneT - return oneT, zeroT, zeroT, zeroT - elseif x == zeroT - return zeroT, zeroT, zeroT, zeroT +# Generic-typed helpers used by the continued-fraction evaluation of I_x(a,b) +# and its partial derivatives. These implement the scalar prefactor K(x;p,q), +# the auxiliary variable f, the continued-fraction coefficients a_n, b_n, and +# their partial derivatives w.r.t. p (≡ a) and q (≡ b). See Boik & Robinson-Cox (1999). + +function _Kfun(x::T, p::T, q::T) where {T<:AbstractFloat} + # K(x;p,q) = x^p (1-x)^{q-1} / (p * B(p,q)) computed in log-space for stability + return exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - logbeta(p, q)) +end + +function _ffun(x::T, p::T, q::T) where {T<:AbstractFloat} + # f = q x / (p (1-x)) — convenience variable appearing in CF coefficients + return q * x / (p * (1 - x)) +end + +function _a1fun(p::T, q::T, f::T) where {T<:AbstractFloat} + # a₁ coefficient of the continued fraction for ₂F₁ representation + return p * f * (q - 1) / (q * (p + 1)) +end + +function _anfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # a_n coefficient (n ≥ 1) of the continued fraction for ₂F₁ in terms of p=a, q=b, f. + # For n=1, falls back to a₁; for n≥2 uses the closed-form product from the Gauss CF. + n == 1 && return _a1fun(p, q, f) + return p^2 * f^2 * (n - 1) * (p + q + n - 2) * (p + n - 1) * (q - n) / + (q^2 * (p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1)) +end + +function _bnfun(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # b_n coefficient (n ≥ 1) of the continued fraction. Derived for the same CF. + x = 2 * (p * f + 2 * q) * n^2 + 2 * (p * f + 2 * q) * (p - 1) * n + p * q * (p - 2 - p * f) + y = q * (p + 2*n - 2) * (p + 2*n) + return x / y +end + +function _dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) where {T<:AbstractFloat} + # ∂K/∂p using digamma identities: d/dp log B(p,q) = ψ(p) - ψ(p+q) + return K * (log(x) - inv(p) + ψpq - ψp) +end + +function _dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) where {T<:AbstractFloat} + # ∂K/∂q using identical pattern + K * (log1p(-x) + ψpq - ψq) +end + +function _dK_dpdq(x::T, p::T, q::T) where {T<:AbstractFloat} + # Convenience: compute (∂K/∂p, ∂K/∂q) together with shared ψ(p+q) + ψ = digamma(p + q) + Kf = _Kfun(x, p, q) + dKdp = _dK_dp(x, p, q, Kf, ψ, digamma(p)) + dKdq = _dK_dq(x, p, q, Kf, ψ, digamma(q)) + return dKdp, dKdq +end + +function _da1_dp(p::T, q::T, f::T) where {T<:AbstractFloat} + # ∂a₁/∂p from the closed form of a₁ + return - _a1fun(p, q, f) / (p + 1) +end + +function _dan_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # ∂a_n/∂p via log-derivative: d a_n = a_n * d log a_n; for n=1, uses ∂a₁/∂p + if n == 1 + return _da1_dp(p, q, f) end + an = _anfun(p, q, f, n) + dlog = inv(p + q + n - 2) + inv(p + n - 1) - inv(p + 2*n - 3) - 2 * inv(p + 2*n - 2) - inv(p + 2*n - 1) + return an * dlog +end + +function _da1_dq(p::T, q::T, f::T) where {T<:AbstractFloat} + # ∂a₁/∂q + return _a1fun(p, q, f) / (q - 1) +end + + +function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # ∂a_n/∂q via log-derivative; for n=1, uses ∂a₁/∂q + if n == 1 + return _da1_dq(p, q, f) + end + an = _anfun(p, q, f, n) + dlog = inv(p + q + n - 2) + inv(q - n) + return an * dlog +end + +function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # ∂b_n/∂p via quotient rule on b_n = N/D. + # Note the internal dependence f(p,q)=q x/(p(1-x)) — terms cancel in N as per derivation. + g = p * f + 2 * q + A = 2 * n^2 + 2 * (p - 1) * n + N1 = g * A + N2 = p * q * (p - 2 - p * f) + N = N1 + N2 + D = q * (p + 2*n - 2) * (p + 2*n) + dN1_dp = 2 * n * g + dN2_dp = q * (2 * p - 2) - p * q * f + dN_dp = dN1_dp + dN2_dp + dD_dp = q * (2 * p + 4 * n - 2) + return (dN_dp * D - N * dD_dp) / (D^2) +end + +function _dbn_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} + # ∂b_n/∂q similarly via quotient rule + g = p * f + 2 * q + A = 2 * n^2 + 2 * (p - 1) * n + N1 = g * A + N2 = p * q * (p - 2 - p * f) + N = N1 + N2 + D = q * (p + 2*n - 2) * (p + 2*n) + g_q = p * (f / q) + 2 + dN1_dq = g_q * A + dN2_dq = p * (p - 2 - p * f) - p^2 * f + dN_dq = dN1_dq + dN2_dq + dD_dq = (p + 2*n - 2) * (p + 2*n) + return (dN_dq * D - N * dD_dq) / (D^2) +end + +function _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) where {T<:AbstractFloat} + # One step of the continuant recurrences: + # A_n = a_n A_{n-2} + b_n A_{n-1} + # B_n = a_n B_{n-2} + b_n B_{n-1} + an = _anfun(p, q, f, n) + bn = _bnfun(p, q, f, n) + An = an * App + bn * Ap + Bn = an * Bpp + bn * Bp + return An, Bn, an, bn +end + +function _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) where {T<:AbstractFloat} + # Derivative propagation for the same recurrences (X∈{A,B}) + return dan * Xpp + an * dXpp + dbn * Xp + bn * dXp +end + +function _beta_inc_grad(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat} + # Compute I_x(a,b) and partial derivatives (∂I/∂a, ∂I/∂b, ∂I/∂x) + # using a differentiated continued fraction with convergence control. + oneT = one(T) + zeroT = zero(T) + + # 1) Boundary cases for x + x == oneT && return oneT, zeroT, zeroT, zeroT + x == zeroT && return zeroT, zeroT, zeroT, zeroT + + # 2) Clamp iteration/tolerance parameters to robust defaults + ϵ = min(err, T(1e-14)) + maxapp = max(1000, maxapp) + minapp = max(5, minapp) + + # 3) Non-boundary path: precompute ∂I/∂x at original (a,b,x) via stable log form dx = exp((a - oneT) * log(x) + (b - oneT) * log1p(-x) - logbeta(a,b)) - # swap tails if necessary - p = a; q = b; x₀ = x; swap = false + + # 4) Optional tail-swap for symmetry and improved CF convergence: + # if x > a/(a+b), evaluate at (p,q,x₀) = (b,a,1-x) and swap back at the end. + p = a + q = b + x₀ = x + swap = false if x > a / (a + b) - x₀ = oneT - x - p = b - q = a + x₀ = oneT - x + p = b + q = a swap = true end - Kfun(x::T, p::T, q::T) = exp(p * log(x) + (q - oneT) * log1p(-x) - log(p) - logbeta(p, q)) - ffun(x::T, p::T, q::T) = q*x/(p*(oneT - x)) - a1fun(p::T, q::T, f::T) = p*f*(q - oneT)/(q*(p + oneT)) - anfun(p::T, q::T, f::T, n::Int) = n == 1 ? a1fun(p, q, f) : - p^2 * f^2 * (T(n) - oneT) * (p + q + T(n) - T(2)) * (p + T(n) - oneT) * (q - T(n)) / - (q^2 * (p + T(2n) - T(3)) * (p + T(2n) - T(2))^2 * (p + T(2n) - oneT)) - function bnfun(p::T, q::T, f::T, n::Int) - x = T(2)*(p*f + T(2)*q)*T(n)^2 + T(2)*(p*f + T(2)*q)*(p - oneT)*T(n) + p*q*(p - T(2) - p*f) - y = (q * (p + T(2n) - T(2)) * (p + T(2n))) - x/y - end - dK_dp(x::T, p::T, q::T, K::T, ψpq::T, ψp::T) = K*(log(x) - inv(p) + ψpq - ψp) - dK_dq(x::T, p::T, q::T, K::T, ψpq::T, ψq::T) = K*(log1p(-x) + ψpq - ψq) - function dK_dpdq(x::T, p::T, q::T) - ψ = digamma(p+q) - Kf = Kfun(x, p, q) - dKdp = dK_dp(x, p, q, Kf, ψ, digamma(p)) - dKdq = dK_dq(x, p, q, Kf, ψ, digamma(q)) - dKdp, dKdq - end - # a_n derivatives via log-derivative - da1_dp(p::T, q::T, f::T) = -a1fun(p, q, f) / (p + oneT) - function dan_dp(p::T, q::T, f::T, n::Int) - if n == 1 - return da1_dp(p, q, f) - end - an = anfun(p, q, f, n) - dlog = inv(p + q + T(n) - T(2)) + inv(p + T(n) - oneT) - inv(p + T(2n) - T(3)) - T(2) * inv(p + T(2n) - T(2)) - inv(p + T(2n) - oneT) - return an * dlog - end - da1_dq(p::T, q::T, f::T) = a1fun(p, q, f) / (q - oneT) - function dan_dq(p::T, q::T, f::T, n::Int) - if n == 1 - return da1_dq(p, q, f) - end - an = anfun(p, q, f, n) - dlog = inv(p + q + T(n) - T(2)) + inv(q - T(n)) - return an * dlog - end - # b_n derivatives via quotient rule, accounting for f_p=-f/p, f_q=f/q which cancel in N - function dbn_dp(p::T, q::T, f::T, n::Int) - g = p * f + T(2) * q - A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n) - N1 = g * A - N2 = p * q * (p - T(2) - p * f) - N = N1 + N2 - D = q * (p + T(2n) - T(2)) * (p + T(2n)) - dN1_dp = T(2) * T(n) * g - dN2_dp = q * (T(2) * p - T(2)) - p * q * f - dN_dp = dN1_dp + dN2_dp - dD_dp = q * (T(2) * p + T(4) * T(n) - T(2)) - return (dN_dp * D - N * dD_dp) / (D^2) - end - function dbn_dq(p::T, q::T, f::T, n::Int) - g = p * f + T(2) * q - A = T(2) * T(n)^2 + T(2) * (p - oneT) * T(n) - N1 = g * A - N2 = p * q * (p - T(2) - p * f) - N = N1 + N2 - D = q * (p + T(2n) - T(2)) * (p + T(2n)) - g_q = p * (f / q) + T(2) - dN1_dq = g_q * A - dN2_dq = p * (p - T(2) - p * f) - p^2 * f - dN_dq = dN1_dq + dN2_dq - dD_dq = (p + T(2n) - T(2)) * (p + T(2n)) - return (dN_dq * D - N * dD_dq) / (D^2) - end - _nextapp(f::T, p::T, q::T, n::Int, App::T, Ap::T, Bpp::T, Bp::T) = begin - an = anfun(p, q, f, n) - bn = bnfun(p, q, f, n) - An = an*App + bn*Ap - Bn = an*Bpp + bn*Bp - An, Bn, an, bn - end - _dnextapp(an::T, bn::T, dan::T, dbn::T, Xpp::T, Xp::T, dXpp::T, dXp::T) = dan * Xpp + an * dXpp + dbn * Xp + bn * dXp - - # compute once - K = Kfun(x₀, p, q) - dK_dp_val, dK_dq_val = dK_dpdq(x₀, p, q) - f = ffun(x₀, p, q) - App = oneT; Ap = oneT; Bpp = zeroT; Bp = oneT - dApp_dp = zeroT; dBpp_dp = zeroT; dAp_dp = zeroT; dBp_dp = zeroT - dApp_dq = zeroT; dBpp_dq = zeroT; dAp_dq = zeroT; dBp_dq = zeroT - dI_dp = T(NaN); dI_dq = T(NaN); Ixpq = T(NaN); Ixpqn = T(NaN); dI_dp_prev = T(NaN); dI_dq_prev = T(NaN) + + # 5) Initialize CF state and derivatives + K = _Kfun(x₀, p, q) + dK_dp_val, dK_dq_val = _dK_dpdq(x₀, p, q) + f = _ffun(x₀, p, q) + App = oneT + Ap = oneT + Bpp = zeroT + Bp = oneT + dApp_dp = zeroT + dBpp_dp = zeroT + dAp_dp = zeroT + dBp_dp = zeroT + dApp_dq = zeroT + dBpp_dq = zeroT + dAp_dq = zeroT + dBp_dq = zeroT + dI_dp = T(NaN) + dI_dq = T(NaN) + Ixpq = T(NaN) + Ixpqn = T(NaN) + dI_dp_prev = T(NaN) + dI_dq_prev = T(NaN) + + # 6) Main CF loop (n from 1): update continuants, scale, form current approximant Cn=A_n/B_n + # and its derivatives to update I and ∂I/∂(p,q). Stop on relative convergence of all. for n=1:maxapp + + # Update continuants. An, Bn, an, bn = _nextapp(f, p, q, n, App, Ap, Bpp, Bp) - dan = dan_dp(p, q, f, n); dbn = dbn_dp(p, q, f, n) - dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp) - dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp) - dan = dan_dq(p, q, f, n); dbn = dbn_dq(p, q, f, n) - dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq) - dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq) - # normalize states to control growth/underflow (scale-invariant) + dan = _dan_dp(p, q, f, n) + dbn = _dbn_dp(p, q, f, n) + dAn_dp = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dp, dAp_dp) + dBn_dp = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dp, dBp_dp) + dan = _dan_dq(p, q, f, n) + dbn = _dbn_dq(p, q, f, n) + dAn_dq = _dnextapp(an, bn, dan, dbn, App, Ap, dApp_dq, dAp_dq) + dBn_dq = _dnextapp(an, bn, dan, dbn, Bpp, Bp, dBpp_dq, dBp_dq) + + # Normalize states to control growth/underflow (scale-invariant transform) s = maximum((abs(An), abs(Bn), abs(Ap), abs(Bp), abs(App), abs(Bpp))) if isfinite(s) && s > zeroT - invs = inv(s) - An *= invs; Bn *= invs - Ap *= invs; Bp *= invs - App *= invs; Bpp *= invs - dAn_dp *= invs; dBn_dp *= invs - dAn_dq *= invs; dBn_dq *= invs - dAp_dp *= invs; dBp_dp *= invs - dApp_dp *= invs; dBpp_dp *= invs - dAp_dq *= invs; dBp_dq *= invs - dApp_dq *= invs; dBpp_dq *= invs + invs = inv(s) + An *= invs + Bn *= invs + Ap *= invs + Bp *= invs + App *= invs + Bpp *= invs + dAn_dp *= invs + dBn_dp *= invs + dAn_dq *= invs + dBn_dq *= invs + dAp_dp *= invs + dBp_dp *= invs + dApp_dp *= invs + dBpp_dp *= invs + dAp_dq *= invs + dBp_dq *= invs + dApp_dq *= invs + dBpp_dq *= invs end - Cn = An/Bn + + # Form current approximant Cn=A_n/B_n and its derivatives + Cn = An/Bn dI_dp = dK_dp_val * Cn + K * (inv(Bn) * dAn_dp - (An/(Bn^2)) * dBn_dp) dI_dq = dK_dq_val * Cn + K * (inv(Bn) * dAn_dq - (An/(Bn^2)) * dBn_dq) Ixpqn = K * Cn + + # Decide convergence: if n >= minapp + # Relative convergence for I, ∂I/∂p, ∂I/∂q (guards against tiny denominators) denomI = max(abs(Ixpqn), abs(Ixpq), eps(T)) denomp = max(abs(dI_dp), abs(dI_dp_prev), eps(T)) denomq = max(abs(dI_dq), abs(dI_dq_prev), eps(T)) - rI = abs(Ixpqn - Ixpq) / denomI - rp = abs(dI_dp - dI_dp_prev) / denomp - rq = abs(dI_dq - dI_dq_prev) / denomq + rI = abs(Ixpqn - Ixpq) / denomI + rp = abs(dI_dp - dI_dp_prev) / denomp + rq = abs(dI_dq - dI_dq_prev) / denomq if max(rI, rp, rq) < ϵ break end end - Ixpq = Ixpqn + Ixpq = Ixpqn dI_dp_prev = dI_dp dI_dq_prev = dI_dq - App = Ap; Bpp = Bp; Ap = An; Bp = Bn - dApp_dp = dAp_dp; dApp_dq = dAp_dq; dBpp_dp = dBp_dp; dBpp_dq = dBp_dq - dAp_dp = dAn_dp; dAp_dq = dAn_dq; dBp_dp = dBn_dp; dBp_dq = dBn_dq + + # Shift CF state for next iteration + App = Ap + Bpp = Bp + Ap = An + Bp = Bn + dApp_dp = dAp_dp + dApp_dq = dAp_dq + dBpp_dp = dBp_dp + dBpp_dq = dBp_dq + dAp_dp = dAn_dp + dAp_dq = dAn_dq + dBp_dp = dBn_dp + dBp_dq = dBn_dq end + + # 7) Undo tail-swap if applied; ∂I/∂x is the pdf at original (a,b,x) if swap return oneT - Ixpqn, -dI_dq, -dI_dp, dx else @@ -468,14 +587,7 @@ function _beta_inc_grad_boik(a::T, b::T, x::T, end end -# Generic wrapper preserving the previous interface/name -function _ibeta_grad_splus(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4)) where {T<:AbstractFloat} - tol = min(err, T(1e-14)) - maxit = max(1000, maxapp) - minit = max(5, minapp) - I, dIa, dIb, dIx = _beta_inc_grad_boik(a, b, x, maxit, minit, tol) - return I, dIa, dIb, dIx -end + @@ -485,7 +597,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number, p, q = beta_inc(a, b, x) # derivatives T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) - _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ ΔaT::T = Δa isa Real ? T(Δa) : zero(T) ΔbT::T = Δb isa Real ? T(Δb) : zero(T) @@ -502,7 +614,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe Tb = ChainRulesCore.ProjectTo(b) Tx = ChainRulesCore.ProjectTo(x) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x))) - _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ function beta_inc_pullback(Δ) Δp, Δq = Δ @@ -517,7 +629,7 @@ end function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Number, b::Number, x::Number, y::Number) p, q = beta_inc(a, b, x, y) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) - _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ ΔaT::T = Δa isa Real ? T(Δa) : zero(T) ΔbT::T = Δb isa Real ? T(Δb) : zero(T) @@ -536,7 +648,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe Tx = ChainRulesCore.ProjectTo(x) Ty = ChainRulesCore.ProjectTo(y) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y))) - _, dIa_, dIb_, dIx_ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, dIx_ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_; dIx::T = dIx_ function beta_inc_pullback(Δ) Δp, Δq = Δ @@ -555,7 +667,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num x, y = beta_inc_inv(a, b, p) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) # Implicit differentiation at solved x: I_x(a,b) = p - _, dIa_, dIb_, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, _ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_ # ∂I/∂x at solved x via stable log-space expression dIx_acc::T = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) @@ -578,7 +690,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::N Tb = ChainRulesCore.ProjectTo(b) Tp = ChainRulesCore.ProjectTo(p) T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(p))) - _, dIa_, dIb_, _ = _ibeta_grad_splus(T(a), T(b), T(x)) + _, dIa_, dIb_, _ = _beta_inc_grad(T(a), T(b), T(x)) dIa::T = dIa_; dIb::T = dIb_ # ∂I/∂x at solved x via stable log-space expression dIx_acc::T = exp(muladd(T(a) - one(T), log(T(x)), muladd(T(b) - one(T), log1p(-T(x)), -logbeta(T(a), T(b))))) From 39eb68a593650e0c8e559ca0b493909ec95219e9 Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Tue, 14 Oct 2025 00:45:44 +0200 Subject: [PATCH 5/6] Expand test points coverage --- test/chainrules.jl | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4b3006fd..1b012976 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -227,9 +227,28 @@ end @testset "incomplete beta: basic test_frule/test_rrule" begin - # Use a small, representative set of interior points (avoid endpoints for FD) - test_points = (0.2, 0.5, 0.8) - ab = (0.7, 2.5) + # Use an expanded set of interior points (avoid endpoints for FD) to exercise many branches: + # Rationale for x values: + # - Include values around 0.1, 0.3, 0.5, 0.7, 0.9 to trigger different code paths. + # - Include 0.14 and 0.28 to straddle the bx ≤ 0.7 power-series threshold for b ≈ 5 and 2.5. + # - Include values near 0.5 (0.49, 0.51) to probe near-symmetry and tail swaps. + # - Include additional midpoints to increase chance that x ≈ a/(a+b) for some (a,b), which makes λ ≈ 0 + # in the large-parameter regime (key for choosing symmetric asymptotics when min(a,b) > 100). + # - Add a few more around 0.6–0.8 to exercise continued fraction vs. asymptotics for large (a,b). + test_points = ( + 0.05, 0.08, 0.10, 0.12, 0.14, 0.18, 0.20, 0.22, 0.26, + 0.28, 0.30, 0.32, 0.35, 0.38, 0.40, 0.42, 0.45, + 0.49, 0.50, 0.51, 0.55, 0.58, 0.60, 0.62, 0.65, + 0.68, 0.70, 0.72, 0.76, 0.80, 0.85, 0.90 + ) + # Rationale for a,b values: + # - <1: 0.4, 0.6 to stress small-parameter power series branches. + # - Near 1: 0.9, 1.1 to test branch boundaries and continuity across a≈1, b≈1. + # - Moderate: 2.5, 5.0 where multiple algorithm choices engage based on x and bx. + # - Large (≥15, ≥40) to drive large-parameter regimes: 16.0, 45.0. + # - Very large (≫100): 100.5, 150.0 to ensure symmetric vs asymmetric asymptotics are exercised when λ + # is small/large, and continued fractions are robust for large shapes. + ab = (0.4, 0.6, 0.9, 1.1, 2.5, 5.0, 16.0, 45.0, 100.5, 150.0) # 3-argument beta_inc(a,b,x) for a in ab, b in ab, x in test_points @@ -256,8 +275,12 @@ end @testset "4-arg beta_inc identities (y = 1 - x)" begin - test_points = (0.2, 0.5, 0.8) - ab = (0.7, 2.5) + # Exercise more regimes while keeping y = 1 - x constraint. + # Same rationale as above for x and (a,b) coverage. + test_points = ( + 0.05, 0.10, 0.12, 0.14, 0.20, 0.28, 0.35, 0.40, 0.49, 0.50, 0.51, 0.60, 0.65, 0.70, 0.72, 0.80, 0.90 + ) + ab = (0.4, 0.6, 0.9, 1.1, 2.5, 5.0, 16.0, 45.0, 100.5, 150.0) for a in ab, b in ab, x in test_points 0.0 < x < 1.0 || continue From 6279c50a8af3ad8767c6c9cd32421106e9a6e16c Mon Sep 17 00:00:00 2001 From: Oskar Laverny Date: Tue, 14 Oct 2025 00:46:03 +0200 Subject: [PATCH 6/6] Fix behavior in extreme cases --- ext/SpecialFunctionsChainRulesCoreExt.jl | 36 ++++++++++++++++++------ 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/ext/SpecialFunctionsChainRulesCoreExt.jl b/ext/SpecialFunctionsChainRulesCoreExt.jl index 8c864224..9db2ef26 100644 --- a/ext/SpecialFunctionsChainRulesCoreExt.jl +++ b/ext/SpecialFunctionsChainRulesCoreExt.jl @@ -391,13 +391,25 @@ end function _dan_dq(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} - # ∂a_n/∂q via log-derivative; for n=1, uses ∂a₁/∂q + # ∂a_n/∂q avoiding the removable singularity at q ≈ n for integer q. + # For n=1, defer to the specific a₁ derivative. if n == 1 return _da1_dq(p, q, f) end - an = _anfun(p, q, f, n) - dlog = inv(p + q + n - 2) + inv(q - n) - return an * dlog + # Use the simplified closed-form of a_n that eliminates explicit q^2 via f: + # a_n = (x/(1-x))^2 * (n-1) * (p+n-1) * (p+q+n-2) * (q-n) / D(p,n) + # where D(p,n) = (p+2n-3)*(p+2n-2)^2*(p+2n-1) and (x/(1-x)) = p*f/q. + # Differentiate only the q-dependent factor G(q) = (p+q+n-2)*(q-n): + # dG/dq = (q-n) + (p+q+n-2) = p + 2q - 2. + + # This is equivalent to + # return _anfun(p,q,f,n) * (inv(p+q+n-2) + inv(q-n)) + # but more precise. + + pfq = (p * f) / q + C = (pfq * pfq) * (n - 1) * (p + n - 1) / + ((p + 2*n - 3) * (p + 2*n - 2)^2 * (p + 2*n - 1)) + return C * (p + 2*q - 2) end function _dbn_dp(p::T, q::T, f::T, n::Int) where {T<:AbstractFloat} @@ -541,11 +553,17 @@ function _beta_inc_grad(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T dBpp_dq *= invs end - # Form current approximant Cn=A_n/B_n and its derivatives - Cn = An/Bn - dI_dp = dK_dp_val * Cn + K * (inv(Bn) * dAn_dp - (An/(Bn^2)) * dBn_dp) - dI_dq = dK_dq_val * Cn + K * (inv(Bn) * dAn_dq - (An/(Bn^2)) * dBn_dq) - Ixpqn = K * Cn + # Form current approximant Cn=A_n/B_n and its derivatives. + # Guard against tiny/zero Bn to avoid NaNs/Inf in divisions. + tiny = sqrt(eps(T)) + absBn = abs(Bn) + sgnBn = ifelse(Bn >= zeroT, oneT, -oneT) + invBn = absBn > tiny && isfinite(absBn) ? inv(Bn) : inv(sgnBn * tiny) + Cn = An * invBn + invBn2 = invBn * invBn + dI_dp = dK_dp_val * Cn + K * (invBn * dAn_dp - (An * invBn2) * dBn_dp) + dI_dq = dK_dq_val * Cn + K * (invBn * dAn_dq - (An * invBn2) * dBn_dq) + Ixpqn = K * Cn # Decide convergence: if n >= minapp