1- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2- f = prob. f
1+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
4+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
5+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
6+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
7+ sol. original)
8+ end
9+
10+ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
11+ @eval begin
12+ function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
13+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
14+ alg:: $ (algType), args... ; kwargs... ) where {uType, T, V, P, iip}
15+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
16+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
17+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
18+ sol. stats, sol. original, left = Dual {T, V, P} (sol. left, partials),
19+ right = Dual {T, V, P} (sol. right, partials))
20+ end
21+ end
22+ end
23+
24+ function __nlsolve_ad (prob, alg, args... ; kwargs... )
325 p = value (prob. p)
426 if prob isa IntervalNonlinearProblem
527 tspan = value .(prob. tspan)
6- newprob = IntervalNonlinearProblem (f, tspan, p; prob. kwargs... )
28+ newprob = IntervalNonlinearProblem (prob . f, tspan, p; prob. kwargs... )
729 else
830 u0 = value (prob. u0)
9- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
31+ newprob = NonlinearProblem (prob . f, u0, p; prob. kwargs... )
1032 end
1133
1234 sol = solve (newprob, alg, args... ; kwargs... )
1335
1436 uu = sol. u
15- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
16- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
37+ f_p = __nlsolve_ ∂f_∂p (prob, prob . f, uu, p)
38+ f_x = __nlsolve_ ∂f_∂u (prob, prob . f, uu, p)
1739
18- z_arr = - inv ( f_x) * f_p
40+ z_arr = - f_x \ f_p
1941
2042 pp = prob. p
2143 sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -30,60 +52,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3052 return sol, partials
3153end
3254
33- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
34- false , <: Dual{T, V, P} }, alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
35- kwargs... ) where {T, V, P}
36- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
37- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
38- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
39- end
40-
41- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
42- false , <: AbstractArray{<:Dual{T, V, P}} },
43- alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P}
44- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
45- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
46- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
47- end
48-
49- function scalar_nlsolve_∂f_∂p (f, u, p)
50- ff = p isa Number ? ForwardDiff. derivative :
51- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
52- return ff (Base. Fix1 (f, u), p)
55+ @inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
56+ if isinplace (prob)
57+ __f = p -> begin
58+ du = similar (u, promote_type (eltype (u), eltype (p)))
59+ f (du, u, p)
60+ return du
61+ end
62+ else
63+ __f = Base. Fix1 (f, u)
64+ end
65+ if p isa Number
66+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
67+ elseif u isa Number
68+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
69+ else
70+ return ForwardDiff. jacobian (__f, p)
71+ end
5372end
5473
55- function scalar_nlsolve_∂f_∂u (f, u, p)
56- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
57- return ff (Base. Fix2 (f, p), u)
74+ @inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
75+ if isinplace (prob)
76+ du = similar (u)
77+ __f = (du, u) -> f (du, u, p)
78+ ForwardDiff. jacobian (__f, du, u)
79+ else
80+ __f = Base. Fix2 (f, p)
81+ if u isa Number
82+ return ForwardDiff. derivative (__f, u)
83+ else
84+ return ForwardDiff. jacobian (__f, u)
85+ end
86+ end
5887end
5988
60- function scalar_nlsolve_dual_soln (u:: Number , partials,
89+ @inline function __nlsolve_dual_soln (u:: Number , partials,
6190 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
6291 return Dual {T, V, P} (u, partials)
6392end
6493
65- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
94+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
6695 :: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
67- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
68- end
69-
70- # avoid ambiguities
71- for Alg in [Bisection]
72- @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
73- <: Dual{T, V, P} }, alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
74- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
75- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
76- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
77- left = Dual {T, V, P} (sol. left, partials),
78- right = Dual {T, V, P} (sol. right, partials))
79- end
80- @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
81- <: AbstractArray{<:Dual{T, V, P}} }, alg:: $Alg , args... ;
82- kwargs... ) where {uType, iip, T, V, P}
83- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
84- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
85- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
86- left = Dual {T, V, P} (sol. left, partials),
87- right = Dual {T, V, P} (sol. right, partials))
88- end
96+ _partials = _restructure (u, partials)
97+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
8998end
0 commit comments