@@ -14,11 +14,19 @@ ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
1414ChainRulesCore. @non_differentiable Integrals. isinplace (f, args... ) # fixes #99
1515ChainRulesCore. @non_differentiable Integrals. init_cacheval (alg, prob)
1616
17- function ChainRulesCore. rrule (:: typeof (Integrals. transformation_if_inf), f, domain)
18- function transformation_if_inf_pullback (Δ)
19- return NoTangent (), Δ...
17+ function ChainRulesCore. rrule (:: typeof (Integrals. __solve), cache:: Integrals.IntegralCache ,
18+ alg:: Integrals.ChangeOfVariables , sensealg, udomain, p;
19+ kwargs... )
20+ _cache, vdomain = Integrals. _change_variables (cache, alg, sensealg, udomain, p)
21+ sol, back = Zygote. pullback ((args... ) -> Integrals. __solve (args... ; kwargs... ),
22+ _cache, alg. alg, sensealg, vdomain, p)
23+ function change_of_variables_pullback (Δ)
24+ return (NoTangent (), back (Δ)... )
2025 end
21- return Integrals. transformation_if_inf (f, domain), transformation_if_inf_pullback
26+ prob = Integrals. build_problem (cache)
27+ _sol = SciMLBase. build_solution (
28+ prob, alg. alg, sol. u, sol. resid, chi = sol. chi, retcode = sol. retcode, stats = sol. stats)
29+ return _sol, change_of_variables_pullback
2230end
2331
2432# we will need to implement the following adjoints when we compute ∂f/∂u
0 commit comments