1- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2- f = prob. f
1+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
4+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
5+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
6+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
7+ sol. original)
8+ end
9+
10+ function __nlsolve_ad (prob:: NonlinearProblem{uType, iip} , alg, args... ;
11+ kwargs... ) where {uType, iip}
312 p = value (prob. p)
4- if prob isa IntervalNonlinearProblem
5- tspan = value .(prob. tspan)
6- newprob = IntervalNonlinearProblem (f, tspan, p; prob. kwargs... )
7- else
8- u0 = value (prob. u0)
9- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
10- end
13+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
1114
1215 sol = solve (newprob, alg, args... ; kwargs... )
1316
1417 uu = sol. u
15- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
16- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
18+ f_p = __nlsolve_ ∂f_∂p (prob, prob . f, uu, p)
19+ f_x = __nlsolve_ ∂f_∂u (prob, prob . f, uu, p)
1720
18- z_arr = - inv ( f_x) * f_p
21+ z_arr = - f_x \ f_p
1922
2023 pp = prob. p
2124 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -30,58 +33,66 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3033 return sol, partials
3134end
3235
33- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
34- false , <: Dual{T, V, P} }, alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
35- kwargs... ) where {T, V, P}
36- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
37- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
38- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
39- end
40-
41- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
42- false , <: AbstractArray{<:Dual{T, V, P}} },
43- alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P}
44- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
45- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
46- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
47- end
48-
49- function scalar_nlsolve_∂f_∂p (f, u, p)
50- ff = p isa Number ? ForwardDiff. derivative :
51- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
52- return ff (Base. Fix1 (f, u), p)
36+ @inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
37+ if isinplace (prob)
38+ __f = p -> begin
39+ du = similar (u, promote_type (eltype (u), eltype (p)))
40+ f (du, u, p)
41+ return du
42+ end
43+ else
44+ __f = Base. Fix1 (f, u)
45+ end
46+ if p isa Number
47+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
48+ elseif u isa Number
49+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
50+ else
51+ return ForwardDiff. jacobian (__f, p)
52+ end
5353end
5454
55- function scalar_nlsolve_∂f_∂u (f, u, p)
56- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
57- return ff (Base. Fix2 (f, p), u)
55+ @inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
56+ if isinplace (prob)
57+ du = similar (u)
58+ __f = (du, u) -> f (du, u, p)
59+ ForwardDiff. jacobian (__f, du, u)
60+ else
61+ __f = Base. Fix2 (f, p)
62+ if u isa Number
63+ return ForwardDiff. derivative (__f, u)
64+ else
65+ return ForwardDiff. jacobian (__f, u)
66+ end
67+ end
5868end
5969
60- function scalar_nlsolve_dual_soln (u:: Number , partials,
70+ @inline function __nlsolve_dual_soln (u:: Number , partials,
6171 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
6272 return Dual {T, V, P} (u, partials)
6373end
6474
65- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
75+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
6676 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
67- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
77+ _partials = _restructure (u, partials)
78+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
6879end
6980
7081# avoid ambiguities
7182for Alg in [Bisection]
7283 @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
7384 <: Dual{T, V, P} }, alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
74- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
75- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
85+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
86+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7687 return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
7788 left = Dual {T, V, P} (sol. left, partials),
7889 right = Dual {T, V, P} (sol. right, partials))
7990 end
8091 @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
8192 <: AbstractArray{<:Dual{T, V, P}} }, alg:: $Alg , args... ;
8293 kwargs... ) where {uType, iip, T, V, P}
83- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
84- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
94+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
95+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
8596 return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
8697 left = Dual {T, V, P} (sol. left, partials),
8798 right = Dual {T, V, P} (sol. right, partials))
0 commit comments