@@ -7,59 +7,61 @@ using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresP
77using SimpleNonlinearSolve: SimpleNonlinearSolve
88import SimpleNonlinearSolve: __internal_solve_up
99
10- function __internal_solve_up (
11- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
12- u0:: TrackedArray , u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
13- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
14- u0_changed, p, p_changed, alg, args... ; kwargs... )
15- end
10+ for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
11+ @eval begin
12+ function __internal_solve_up (prob:: $ (pType), sensealg, u0:: TrackedArray , u0_changed,
13+ p:: TrackedArray , p_changed, alg, args... ; kwargs... )
14+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
15+ u0_changed, p, p_changed, alg, args... ; kwargs... )
16+ end
1617
17- function __internal_solve_up (
18- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
19- u0, u0_changed, p:: TrackedArray , p_changed, alg, args... ; kwargs... )
20- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
21- u0_changed, p, p_changed, alg, args... ; kwargs... )
22- end
18+ function __internal_solve_up (prob:: $ (pType), sensealg, u0, u0_changed,
19+ p:: TrackedArray , p_changed, alg, args... ; kwargs... )
20+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
21+ u0_changed, p, p_changed, alg, args... ; kwargs... )
22+ end
2323
24- function __internal_solve_up (
25- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} , sensealg,
26- u0:: TrackedArray , u0_changed, p, p_changed, alg, args... ; kwargs... )
27- return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
28- u0_changed, p, p_changed, alg, args... ; kwargs... )
29- end
24+ function __internal_solve_up (prob:: $ (pType), sensealg, u0:: TrackedArray ,
25+ u0_changed, p, p_changed, alg, args... ; kwargs... )
26+ return ReverseDiff. track (__internal_solve_up, prob, sensealg, u0,
27+ u0_changed, p, p_changed, alg, args... ; kwargs... )
28+ end
3029
31- function __internal_solve_up (prob :: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
32- sensealg, u0:: AbstractArray{<:TrackedReal} , u0_changed,
33- p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
34- return __internal_solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
35- ArrayInterface. aos_to_soa (p), true , alg, args... ; kwargs... )
36- end
30+ function __internal_solve_up (
31+ prob :: $ (pType), sensealg, u0:: AbstractArray{<:TrackedReal} , u0_changed,
32+ p:: AbstractArray{<:TrackedReal} , p_changed, alg, args... ; kwargs... )
33+ return __internal_solve_up (prob, sensealg, ArrayInterface. aos_to_soa (u0), true ,
34+ ArrayInterface. aos_to_soa (p), true , alg, args... ; kwargs... )
35+ end
3736
38- function __internal_solve_up (
39- prob :: Union{NonlinearProblem, NonlinearLeastSquaresProblem } , sensealg, u0,
40- u0_changed, p :: AbstractArray{<:TrackedReal} , p_changed, alg, args ... ; kwargs ... )
41- return __internal_solve_up ( prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
42- true , alg, args... ; kwargs... )
43- end
37+ function __internal_solve_up (prob :: $ (pType), sensealg, u0, u0_changed,
38+ p :: AbstractArray{<:TrackedReal } , p_changed, alg, args ... ; kwargs ... )
39+ return __internal_solve_up (
40+ prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
41+ true , alg, args... ; kwargs... )
42+ end
4443
45- function __internal_solve_up (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
46- sensealg, u0:: AbstractArray{<:TrackedReal} ,
47- u0_changed, p, p_changed, alg, args... ; kwargs... )
48- return __internal_solve_up (prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
49- true , alg, args... ; kwargs... )
50- end
44+ function __internal_solve_up (
45+ prob:: $ (pType), sensealg, u0:: AbstractArray{<:TrackedReal} ,
46+ u0_changed, p, p_changed, alg, args... ; kwargs... )
47+ return __internal_solve_up (
48+ prob, sensealg, u0, true , ArrayInterface. aos_to_soa (p),
49+ true , alg, args... ; kwargs... )
50+ end
5151
52- ReverseDiff. @grad function __internal_solve_up (
53- prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
54- sensealg, u0, u0_changed, p, p_changed, alg, args... ; kwargs... )
55- out, ∇internal = DiffEqBase. _solve_adjoint (
56- prob, sensealg, ReverseDiff. value (u0), ReverseDiff. value (p),
57- ReverseDiffOriginator (), alg, args... ; kwargs... )
58- function ∇__internal_solve_up (_args... )
59- ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal (_args... )
60- return (∂prob, ∂sensealg, ∂u0, nothing , ∂p, nothing , nothing , ∂args... )
52+ ReverseDiff. @grad function __internal_solve_up (
53+ prob:: $ (pType), sensealg, u0, u0_changed,
54+ p, p_changed, alg, args... ; kwargs... )
55+ out, ∇internal = DiffEqBase. _solve_adjoint (
56+ prob, sensealg, ReverseDiff. value (u0), ReverseDiff. value (p),
57+ ReverseDiffOriginator (), alg, args... ; kwargs... )
58+ function ∇__internal_solve_up (_args... )
59+ ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal (_args... )
60+ return (∂prob, ∂sensealg, ∂u0, nothing , ∂p, nothing , nothing , ∂args... )
61+ end
62+ return Array (out), ∇__internal_solve_up
63+ end
6164 end
62- return Array (out), ∇__internal_solve_up
6365end
6466
6567end
0 commit comments