Skip to content

Commit 95e27e2

Browse files
committed
attempt to fix derivative tests
1 parent 5d1bebd commit 95e27e2

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

ext/IntegralsZygoteExt.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,19 @@ ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
1414
ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99
1515
ChainRulesCore.@non_differentiable Integrals.init_cacheval(alg, prob)
1616

17-
function ChainRulesCore.rrule(::typeof(Integrals.transformation_if_inf), f, domain)
18-
function transformation_if_inf_pullback(Δ)
19-
return NoTangent(), Δ...
17+
function ChainRulesCore.rrule(::typeof(Integrals.__solve), cache::Integrals.IntegralCache,
18+
alg::Integrals.ChangeOfVariables, sensealg, udomain, p;
19+
kwargs...)
20+
_cache, vdomain = Integrals._change_variables(cache, alg, sensealg, udomain, p)
21+
sol, back = Zygote.pullback((args...) -> Integrals.__solve(args...; kwargs...),
22+
_cache, alg.alg, sensealg, vdomain, p)
23+
function change_of_variables_pullback(Δ)
24+
return (NoTangent(), back(Δ)...)
2025
end
21-
return Integrals.transformation_if_inf(f, domain), transformation_if_inf_pullback
26+
prob = Integrals.build_problem(cache)
27+
_sol = SciMLBase.build_solution(
28+
prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats)
29+
return _sol, change_of_variables_pullback
2230
end
2331

2432
# we will need to implement the following adjoints when we compute ∂f/∂u

src/Integrals.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ end
7878

7979
function __solve(cache::IntegralCache, alg::ChangeOfVariables, sensealg, udomain, p;
8080
kwargs...)
81+
_cache, vdomain = _change_variables(cache, alg, sensealg, udomain, p)
82+
sol = __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...)
83+
prob = build_problem(cache)
84+
return SciMLBase.build_solution(
85+
prob, alg.alg, sol.u, sol.resid, chi = sol.chi, retcode = sol.retcode, stats = sol.stats)
86+
end
87+
88+
function _change_variables(cache, alg, sensealg, udomain, p)
8189
cacheval = cache.cacheval.alg
8290
g, vdomain = alg.fu2gv(cache.f, udomain)
8391
_cache = IntegralCache(Val(isinplace(g)),
@@ -89,7 +97,7 @@ function __solve(cache::IntegralCache, alg::ChangeOfVariables, sensealg, udomain
8997
sensealg,
9098
cache.kwargs,
9199
cacheval)
92-
return __solve(_cache, alg.alg, sensealg, vdomain, p; kwargs...)
100+
return _cache, vdomain
93101
end
94102

95103
function get_prototype(prob::IntegralProblem)

0 commit comments

Comments
 (0)