1- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2- f = prob. f
1+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
4+ kwargs... ) where {T, V, P, iip}
5+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
6+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
8+ end
9+
10+ # Differentiate Out-of-Place Nonlinear Root Finding Problems
11+ function __nlsolve_ad (prob:: NonlinearProblem{uType, false} , alg, args... ;
12+ kwargs... ) where {uType}
313 p = value (prob. p)
4- u0 = value (prob. u0)
5- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
14+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
615
716 sol = solve (newprob, alg, args... ; kwargs... )
817
918 uu = sol. u
10- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
11- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
19+ f_p = __nlsolve_ ∂f_∂p (prob . f, uu, p)
20+ f_x = __nlsolve_ ∂f_∂u (prob . f, uu, p)
1221
13- z_arr = - inv ( f_x) * f_p
22+ z_arr = - f_x \ f_p
1423
1524 pp = prob. p
1625 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -25,39 +34,33 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2534 return sol, partials
2635end
2736
28- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
29- false , <: Dual{T, V, P} }, alg:: AbstractNonlinearSolveAlgorithm , args... ;
30- kwargs... ) where {T, V, P}
31- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
32- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
33- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
34- end
35-
36- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
37- false , <: AbstractArray{<:Dual{T, V, P}} }, alg:: AbstractNonlinearSolveAlgorithm ,
38- args... ; kwargs... ) where {T, V, P}
39- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
40- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
41- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
42- end
43-
44- function scalar_nlsolve_∂f_∂p (f, u, p)
45- ff = p isa Number ? ForwardDiff. derivative :
46- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
47- return ff (Base. Fix1 (f, u), p)
37+ @inline function __nlsolve_∂f_∂p (f:: F , u, p) where {F}
38+ __f = Base. Fix1 (f, u)
39+ if p isa Number
40+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
41+ elseif u isa Number
42+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
43+ else
44+ return ForwardDiff. jacobian (__f, p)
45+ end
4846end
4947
50- function scalar_nlsolve_∂f_∂u (f, u, p)
51- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
52- return ff (Base. Fix2 (f, p), u)
48+ @inline function __nlsolve_∂f_∂u (f:: F , u, p) where {F}
49+ __f = Base. Fix2 (f, p)
50+ if u isa Number
51+ return ForwardDiff. derivative (__f, u)
52+ else
53+ return ForwardDiff. jacobian (__f, u)
54+ end
5355end
5456
55- function scalar_nlsolve_dual_soln (u:: Number , partials,
57+ @inline function __nlsolve_dual_soln (u:: Number , partials,
5658 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
5759 return Dual {T, V, P} (u, partials)
5860end
5961
60- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
62+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
6163 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
62- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
64+ _partials = _restructure (u, partials)
65+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
6366end
0 commit comments