|
| 1 | +module SimpleNonlinearSolveReverseDiffExt |
| 2 | + |
| 3 | +using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve |
| 4 | +import ReverseDiff: TrackedArray, TrackedReal |
| 5 | +import SimpleNonlinearSolve: __internal_solve_up |
| 6 | + |
| 7 | +function __internal_solve_up( |
| 8 | + prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
| 9 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 10 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 11 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 12 | +end |
| 13 | + |
| 14 | +function __internal_solve_up( |
| 15 | + prob::NonlinearProblem, sensealg, u0, u0_changed, |
| 16 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 17 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 18 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 19 | +end |
| 20 | + |
| 21 | +function __internal_solve_up( |
| 22 | + prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
| 23 | + p, p_changed, alg, args...; kwargs...) |
| 24 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 25 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 26 | +end |
| 27 | + |
| 28 | +function __internal_solve_up(prob::NonlinearProblem, sensealg, |
| 29 | + u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, |
| 30 | + p_changed, alg, args...; kwargs...) |
| 31 | + return __internal_solve_up( |
| 32 | + prob, sensealg, ArrayInterface.aos_to_soa(u0), true, |
| 33 | + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 34 | +end |
| 35 | + |
| 36 | +function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed, |
| 37 | + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
| 38 | + return __internal_solve_up( |
| 39 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 40 | +end |
| 41 | + |
| 42 | +function __internal_solve_up(prob::NonlinearProblem, sensealg, |
| 43 | + u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) |
| 44 | + return __internal_solve_up( |
| 45 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 46 | +end |
| 47 | + |
| 48 | +ReverseDiff.@grad function __internal_solve_up( |
| 49 | + prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) |
| 50 | + out, ∇internal = DiffEqBase._solve_adjoint( |
| 51 | + prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), |
| 52 | + SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...) |
| 53 | + function ∇__internal_solve_up(_args...) |
| 54 | + ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) |
| 55 | + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) |
| 56 | + end |
| 57 | + return Array(out), ∇__internal_solve_up |
| 58 | +end |
| 59 | + |
| 60 | +end |
0 commit comments