1313ChainRulesCore. @non_differentiable Integrals. checkkwargs (kwargs... )
1414ChainRulesCore. @non_differentiable Integrals. isinplace (f, args... ) # fixes #99
1515ChainRulesCore. @non_differentiable Integrals. init_cacheval (alg, prob)
16+ ChainRulesCore. @non_differentiable Integrals. substitute_f (args... ) # use ∂f/∂p instead
17+ ChainRulesCore. @non_differentiable Integrals. substitute_v (args... ) # TODO for ∂f/∂u
18+ ChainRulesCore. @non_differentiable Integrals. substitute_bv (args... ) # TODO for ∂f/∂u
1619
17- function ChainRulesCore. rrule (:: typeof (Integrals. __solve), cache:: Integrals.IntegralCache ,
18- alg:: Integrals.ChangeOfVariables , sensealg, udomain, p;
19- kwargs... )
20- _cache, vdomain = Integrals. _change_variables (cache, alg, sensealg, udomain, p)
21- sol, back = Zygote. pullback ((args... ) -> Integrals. __solve (args... ; kwargs... ),
22- _cache, alg. alg, sensealg, vdomain, p)
23- function change_of_variables_pullback (Δ)
24- return (NoTangent (), back (Δ)... )
20+ # TODO move this adjoint to SciMLBase
21+ function ChainRulesCore. rrule (
22+ :: typeof (SciMLBase. build_solution), prob:: IntegralProblem , alg, u, resid; kwargs... )
23+ function build_integral_solution_pullback (Δ)
24+ return NoTangent (), NoTangent (), NoTangent (), Δ, NoTangent ()
2525 end
26- prob = Integrals. build_problem (cache)
27- _sol = SciMLBase. build_solution (
28- prob, alg. alg, sol. u, sol. resid, chi = sol. chi, retcode = sol. retcode, stats = sol. stats)
29- return _sol, change_of_variables_pullback
26+ return SciMLBase. build_solution (prob, alg, u, resid; kwargs... ),
27+ build_integral_solution_pullback
3028end
3129
32- # we will need to implement the following adjoints when we compute ∂f/∂u
33- function ChainRulesCore. rrule (:: typeof (Integrals. substitute_v), args... )
34- function substitute_v_pullback (_)
35- return NoTangent (), ntuple (_ -> NoTangent (), length (args))...
36- end
37- return Integrals. substitute_v (args... ), substitute_v_pullback
38- end
39- function ChainRulesCore. rrule (:: typeof (Integrals. substitute_bv), args... )
40- function substitute_bv_pullback (_)
41- return NoTangent (), ntuple (_ -> NoTangent (), length (args))...
42- end
43- return Integrals. substitute_bv (args... ), substitute_bv_pullback
44- end
4530function ChainRulesCore. rrule (:: typeof (Integrals. _evaluate!), f, y, u, p)
4631 out, back = Zygote. pullback (y, u, p) do y, u, p
4732 b = Zygote. Buffer (y)
@@ -51,6 +36,16 @@ function ChainRulesCore.rrule(::typeof(Integrals._evaluate!), f, y, u, p)
5136 out, Δ -> (NoTangent (), NoTangent (), back (Δ)... )
5237end
5338
39+ function ChainRulesCore. rrule (:: typeof (Integrals. u2t), lb, ub)
40+ tlb, tub = out = Integrals. u2t (lb, ub)
41+ function u2t_pullback (Δ)
42+ _, lbjac = Integrals. t2ujac (tlb, lb, ub)
43+ _, ubjac = Integrals. t2ujac (tub, lb, ub)
44+ return NoTangent (), Δ[1 ] / lbjac, Δ[2 ] / ubjac
45+ end
46+ return out, u2t_pullback
47+ end
48+
5449function ChainRulesCore. rrule (:: typeof (Integrals. __solvebp), cache, alg, sensealg, domain,
5550 p;
5651 kwargs... )
0 commit comments