1- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2- f = prob. f
1+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
4+ kwargs... ) where {T, V, P, iip}
5+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
6+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
8+ sol. original)
9+ end
10+
11+ @concrete mutable struct NonlinearSolveForwardDiffCache
12+ cache
13+ prob
14+ alg
15+ p
16+ values_p
17+ partials_p
18+ end
19+
20+ @inline function __has_duals (:: Union {<: Dual{T, V, P} ,
21+ <: AbstractArray{<:Dual{T, V, P}} }) where {T, V, P}
22+ return true
23+ end
24+ @inline __has_duals (:: Any ) = false
25+
26+ function SciMLBase. reinit! (cache:: NonlinearSolveForwardDiffCache ; p = cache. p,
27+ u0 = get_u (cache. cache), kwargs... )
28+ inner_cache = SciMLBase. reinit! (cache. cache; p = value (p), u0 = value (u0), kwargs... )
29+ cache. cache = inner_cache
30+ cache. p = p
31+ cache. values_p = value (p)
32+ cache. partials_p = ForwardDiff. partials (p)
33+ return cache
34+ end
35+
36+ function SciMLBase. init (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
37+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
38+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
39+ kwargs... ) where {T, V, P, iip}
340 p = value (prob. p)
4- u0 = value (prob. u0)
5- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
41+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
42+ cache = init (newprob, alg, args... ; kwargs... )
43+ return NonlinearSolveForwardDiffCache (cache, newprob, alg, prob. p, p,
44+ ForwardDiff. partials (prob. p))
45+ end
46+
47+ function SciMLBase. solve! (cache:: NonlinearSolveForwardDiffCache )
48+ sol = solve! (cache. cache)
49+ prob = cache. prob
50+
51+ uu = sol. u
52+ f_p = __nlsolve_∂f_∂p (prob, prob. f, uu, cache. values_p)
53+ f_x = __nlsolve_∂f_∂u (prob, prob. f, uu, cache. values_p)
54+
55+ z_arr = - f_x \ f_p
56+
57+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
58+ if cache. p isa Number
59+ partials = sumfun ((z_arr, cache. p))
60+ else
61+ partials = sum (sumfun, zip (eachcol (z_arr), cache. p))
62+ end
63+
64+ dual_soln = __nlsolve_dual_soln (sol. u, partials, cache. p)
65+ return SciMLBase. build_solution (prob, cache. alg, dual_soln, sol. resid; sol. retcode,
66+ sol. stats, sol. original)
67+ end
68+
69+ function __nlsolve_ad (prob:: NonlinearProblem{uType, iip} , alg, args... ;
70+ kwargs... ) where {uType, iip}
71+ p = value (prob. p)
72+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
673
774 sol = solve (newprob, alg, args... ; kwargs... )
875
976 uu = sol. u
10- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
11- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
77+ f_p = __nlsolve_ ∂f_∂p (prob, prob . f, uu, p)
78+ f_x = __nlsolve_ ∂f_∂u (prob, prob . f, uu, p)
1279
13- z_arr = - inv ( f_x) * f_p
80+ z_arr = - f_x \ f_p
1481
1582 pp = prob. p
1683 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -25,39 +92,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2592 return sol, partials
2693end
2794
28- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
29- false , <: Dual{T, V, P} }, alg:: AbstractNonlinearSolveAlgorithm , args... ;
30- kwargs... ) where {T, V, P}
31- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
32- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
33- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
34- end
35-
36- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
37- false , <: AbstractArray{<:Dual{T, V, P}} }, alg:: AbstractNonlinearSolveAlgorithm ,
38- args... ; kwargs... ) where {T, V, P}
39- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
40- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
41- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
42- end
43-
44- function scalar_nlsolve_∂f_∂p (f, u, p)
45- ff = p isa Number ? ForwardDiff. derivative :
46- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
47- return ff (Base. Fix1 (f, u), p)
95+ @inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
96+ if isinplace (prob)
97+ __f = p -> begin
98+ du = similar (u, promote_type (eltype (u), eltype (p)))
99+ f (du, u, p)
100+ return du
101+ end
102+ else
103+ __f = Base. Fix1 (f, u)
104+ end
105+ if p isa Number
106+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
107+ elseif u isa Number
108+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
109+ else
110+ return ForwardDiff. jacobian (__f, p)
111+ end
48112end
49113
50- function scalar_nlsolve_∂f_∂u (f, u, p)
51- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
52- return ff (Base. Fix2 (f, p), u)
114+ @inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
115+ if isinplace (prob)
116+ du = similar (u)
117+ __f = (du, u) -> f (du, u, p)
118+ ForwardDiff. jacobian (__f, du, u)
119+ else
120+ __f = Base. Fix2 (f, p)
121+ if u isa Number
122+ return ForwardDiff. derivative (__f, u)
123+ else
124+ return ForwardDiff. jacobian (__f, u)
125+ end
126+ end
53127end
54128
55- function scalar_nlsolve_dual_soln (u:: Number , partials,
129+ @inline function __nlsolve_dual_soln (u:: Number , partials,
56130 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
57131 return Dual {T, V, P} (u, partials)
58132end
59133
60- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
134+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
61135 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
62- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
136+ _partials = _restructure (u, partials)
137+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
63138end
0 commit comments