|
1 | 1 | module SimpleNonlinearSolveReverseDiffExt |
2 | 2 |
|
3 | | -using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve |
4 | | -import ReverseDiff: TrackedArray, TrackedReal |
| 3 | +using ArrayInterface: ArrayInterface |
| 4 | +using DiffEqBase: DiffEqBase |
| 5 | +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal |
| 6 | +using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem |
| 7 | +using SimpleNonlinearSolve: SimpleNonlinearSolve |
5 | 8 | import SimpleNonlinearSolve: __internal_solve_up |
6 | 9 |
|
7 | | -function __internal_solve_up( |
8 | | - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
9 | | - p::TrackedArray, p_changed, alg, args...; kwargs...) |
10 | | - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
11 | | - u0_changed, p, p_changed, alg, args...; kwargs...) |
12 | | -end |
| 10 | +for pType in (NonlinearProblem, NonlinearLeastSquaresProblem) |
| 11 | + @eval begin |
| 12 | + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed, |
| 13 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 14 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 15 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 16 | + end |
13 | 17 |
|
14 | | -function __internal_solve_up( |
15 | | - prob::NonlinearProblem, sensealg, u0, u0_changed, |
16 | | - p::TrackedArray, p_changed, alg, args...; kwargs...) |
17 | | - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
18 | | - u0_changed, p, p_changed, alg, args...; kwargs...) |
19 | | -end |
| 18 | + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, |
| 19 | + p::TrackedArray, p_changed, alg, args...; kwargs...) |
| 20 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 21 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 22 | + end |
20 | 23 |
|
21 | | -function __internal_solve_up( |
22 | | - prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed, |
23 | | - p, p_changed, alg, args...; kwargs...) |
24 | | - return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
25 | | - u0_changed, p, p_changed, alg, args...; kwargs...) |
26 | | -end |
| 24 | + function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, |
| 25 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 26 | + return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0, |
| 27 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 28 | + end |
27 | 29 |
|
28 | | -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
29 | | - u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal}, |
30 | | - p_changed, alg, args...; kwargs...) |
31 | | - return __internal_solve_up( |
32 | | - prob, sensealg, ArrayInterface.aos_to_soa(u0), true, |
33 | | - ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
34 | | -end |
| 30 | + function __internal_solve_up( |
| 31 | + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed, |
| 32 | + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
| 33 | + return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true, |
| 34 | + ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
| 35 | + end |
35 | 36 |
|
36 | | -function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed, |
37 | | - p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
38 | | - return __internal_solve_up( |
39 | | - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
40 | | -end |
| 37 | + function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed, |
| 38 | + p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...) |
| 39 | + return __internal_solve_up( |
| 40 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), |
| 41 | + true, alg, args...; kwargs...) |
| 42 | + end |
41 | 43 |
|
42 | | -function __internal_solve_up(prob::NonlinearProblem, sensealg, |
43 | | - u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...) |
44 | | - return __internal_solve_up( |
45 | | - prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...) |
46 | | -end |
| 44 | + function __internal_solve_up( |
| 45 | + prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, |
| 46 | + u0_changed, p, p_changed, alg, args...; kwargs...) |
| 47 | + return __internal_solve_up( |
| 48 | + prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), |
| 49 | + true, alg, args...; kwargs...) |
| 50 | + end |
47 | 51 |
|
48 | | -ReverseDiff.@grad function __internal_solve_up( |
49 | | - prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) |
50 | | - out, ∇internal = DiffEqBase._solve_adjoint( |
51 | | - prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), |
52 | | - SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...) |
53 | | - function ∇__internal_solve_up(_args...) |
54 | | - ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) |
55 | | - return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) |
| 52 | + ReverseDiff.@grad function __internal_solve_up( |
| 53 | + prob::$(pType), sensealg, u0, u0_changed, |
| 54 | + p, p_changed, alg, args...; kwargs...) |
| 55 | + out, ∇internal = DiffEqBase._solve_adjoint( |
| 56 | + prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), |
| 57 | + ReverseDiffOriginator(), alg, args...; kwargs...) |
| 58 | + function ∇__internal_solve_up(_args...) |
| 59 | + ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...) |
| 60 | + return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) |
| 61 | + end |
| 62 | + return Array(out), ∇__internal_solve_up |
| 63 | + end |
56 | 64 | end |
57 | | - return Array(out), ∇__internal_solve_up |
58 | 65 | end |
59 | 66 |
|
60 | 67 | end |
0 commit comments