1- const DualNonlinearProblem = NonlinearProblem{<: Union{Number, <:AbstractArray} , iip,
2- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
1+ const DualNonlinearProblem = NonlinearProblem{
2+ <: Union{Number, <:AbstractArray} , iip,
3+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
4+ } where {iip, T, V, P}
35const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
46 <: Union{Number, <:AbstractArray} , iip,
5- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
7+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
8+ } where {iip, T, V, P}
69const DualAbstractNonlinearProblem = Union{
7- DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
10+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem
11+ }
812
913for algType in ALL_SOLVER_TYPES
1014 @eval function SciMLBase. __solve (
11- prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
12- sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
13- dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
15+ prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
16+ )
17+ sol, partials = NonlinearSolveBase. nonlinearsolve_forwarddiff_solve (
18+ prob, alg, args... ; kwargs...
19+ )
20+ dual_soln = NonlinearSolveBase. nonlinearsolve_dual_solution (sol. u, partials, prob. p)
1421 return SciMLBase. build_solution (
15- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
22+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
23+ )
1624 end
1725end
1826
19- @concrete mutable struct NonlinearSolveForwardDiffCache
27+ @concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
2028 cache
2129 prob
2230 alg
2533 partials_p
2634end
2735
28- @internal_caches NonlinearSolveForwardDiffCache :cache
29-
30- function reinit_cache! (cache:: NonlinearSolveForwardDiffCache ;
31- p = cache. p, u0 = get_u (cache. cache), kwargs... )
32- inner_cache = reinit_cache! (cache. cache; p = __value (p), u0 = __value (u0), kwargs... )
36+ function InternalAPI. reinit! (
37+ cache:: NonlinearSolveForwardDiffCache , args... ;
38+ p = cache. p, u0 = NonlinearSolveBase. get_u (cache. cache), kwargs...
39+ )
40+ inner_cache = InternalAPI. reinit! (
41+ cache. cache; p = nodual_value (p), u0 = nodual_value (u0), kwargs...
42+ )
3343 cache. cache = inner_cache
3444 cache. p = p
35- cache. values_p = __value (p)
45+ cache. values_p = nodual_value (p)
3646 cache. partials_p = ForwardDiff. partials (p)
3747 return cache
3848end
3949
4050for algType in ALL_SOLVER_TYPES
51+ # XXX : Extend to DualNonlinearLeastSquaresProblem
4152 @eval function SciMLBase. __init (
42- prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
43- p = __value (prob. p)
44- newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
53+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs...
54+ )
55+ p = nodual_value (prob. p)
56+ newprob = SciMLBase. remake (prob; u0 = nodual_value (prob. u0), p)
4557 cache = init (newprob, alg, args... ; kwargs... )
4658 return NonlinearSolveForwardDiffCache (
47- cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
59+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p)
60+ )
4861 end
4962end
5063
51- function SciMLBase . solve! (cache:: NonlinearSolveForwardDiffCache )
64+ function CommonSolve . solve! (cache:: NonlinearSolveForwardDiffCache )
5265 sol = solve! (cache. cache)
5366 prob = cache. prob
5467
5568 uu = sol. u
56- Jₚ = nonlinearsolve_∂f_∂p (prob, prob. f, uu, cache. values_p)
57- Jᵤ = nonlinearsolve_∂f_∂u (prob, prob. f, uu, cache. values_p)
69+ Jₚ = NonlinearSolveBase . nonlinearsolve_∂f_∂p (prob, prob. f, uu, cache. values_p)
70+ Jᵤ = NonlinearSolveBase . nonlinearsolve_∂f_∂u (prob, prob. f, uu, cache. values_p)
5871
5972 z_arr = - Jᵤ \ Jₚ
6073
@@ -65,11 +78,12 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
6578 partials = sum (sumfun, zip (eachcol (z_arr), cache. p))
6679 end
6780
68- dual_soln = nonlinearsolve_dual_solution (sol. u, partials, cache. p)
81+ dual_soln = NonlinearSolveBase . nonlinearsolve_dual_solution (sol. u, partials, cache. p)
6982 return SciMLBase. build_solution (
70- prob, cache. alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
83+ prob, cache. alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
84+ )
7185end
7286
73- @inline __value (x) = x
74- @inline __value (x:: Dual ) = ForwardDiff. value (x)
75- @inline __value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
87+ nodual_value (x) = x
88+ nodual_value (x:: Dual ) = ForwardDiff. value (x)
89+ nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
0 commit comments