@@ -44,26 +44,56 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
4444 end
4545end
4646
47- @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} < :
48- AbstractNonlinearSolveCache{iip, false }
47+ @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit } < :
48+ AbstractNonlinearSolveCache{iip, timeit }
4949 caches
5050 alg
51+ best:: Int
5152 current:: Int
53+ nsteps:: Int
54+ total_time:: Float64
55+ maxtime
56+ retcode:: ReturnCode.T
57+ force_stop:: Bool
58+ maxiters:: Int
59+ end
60+
61+ function Base. show (
62+ io:: IO , cache:: NonlinearSolvePolyAlgorithmCache{pType, N} ) where {pType, N}
63+ problem_kind = ifelse (pType == :NLS , " NonlinearProblem" , " NonlinearLeastSquaresProblem" )
64+ println (io, " NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms" )
65+ best_alg = ifelse (cache. best == - 1 , " nothing" , cache. best)
66+ println (io, " Best algorithm: $(best_alg) " )
67+ println (io, " Current algorithm: $(cache. current) " )
68+ println (io, " nsteps: $(cache. nsteps) " )
69+ println (io, " retcode: $(cache. retcode) " )
70+ __show_cache (io, cache. caches[cache. current], 0 )
5271end
5372
5473function reinit_cache! (cache:: NonlinearSolvePolyAlgorithmCache , args... ; kwargs... )
5574 foreach (c -> reinit_cache! (c, args... ; kwargs... ), cache. caches)
5675 cache. current = 1
76+ cache. nsteps = 0
77+ cache. total_time = 0.0
5778end
5879
5980for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
6081 algType = NonlinearSolvePolyAlgorithm{pType}
6182 @eval begin
62- function SciMLBase. __init (
63- prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
64- return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N} (
65- map (solver -> SciMLBase. __init (prob, solver, args... ; kwargs... ), alg. algs),
66- alg, 1 )
83+ function SciMLBase. __init (prob:: $probType , alg:: $algType{N} , args... ;
84+ maxtime = nothing , maxiters = 1000 , kwargs... ) where {N}
85+ return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N, maxtime !== nothing} (
86+ map (solver -> SciMLBase. __init (prob, solver, args... ; maxtime, kwargs... ),
87+ alg. algs),
88+ alg,
89+ - 1 ,
90+ 1 ,
91+ 0 ,
92+ 0.0 ,
93+ maxtime,
94+ ReturnCode. Default,
95+ false ,
96+ maxiters)
6797 end
6898 end
6999end
89119 fu = get_fu ($ (cache_syms[i]))
90120 return SciMLBase. build_solution (
91121 $ (sol_syms[i]). prob, cache. alg, u, fu;
92- retcode = ReturnCode . Success , stats,
122+ retcode = $ (sol_syms[i]) . retcode , stats,
93123 original = $ (sol_syms[i]), trace = $ (sol_syms[i]). trace)
94124 end
95125 cache. current = $ (i + 1 )
@@ -103,12 +133,11 @@ end
103133 end
104134 push! (calls,
105135 quote
106- retcode = ReturnCode. MaxIters
107-
108136 fus = tuple ($ (Tuple (resids)... ))
109137 minfu, idx = __findmin (cache. caches[1 ]. internalnorm, fus)
110138 stats = cache. caches[idx]. stats
111- u = cache. caches[idx]. u
139+ u = get_u (cache. caches[idx])
140+ retcode = cache. caches[idx]. retcode
112141
113142 return SciMLBase. build_solution (cache. caches[idx]. prob, cache. alg, u, fus[idx];
114143 retcode, stats, cache. caches[idx]. trace)
117146 return Expr (:block , calls... )
118147end
119148
149+ @generated function __step! (
150+ cache:: NonlinearSolvePolyAlgorithmCache{iip, N} , args... ; kwargs... ) where {iip, N}
151+ calls = []
152+ cache_syms = [gensym (" cache" ) for i in 1 : N]
153+ for i in 1 : N
154+ push! (calls,
155+ quote
156+ $ (cache_syms[i]) = cache. caches[$ (i)]
157+ if $ (i) == cache. current
158+ __step! ($ (cache_syms[i]), args... ; kwargs... )
159+ $ (cache_syms[i]). nsteps += 1
160+ if ! not_terminated ($ (cache_syms[i]))
161+ if SciMLBase. successful_retcode ($ (cache_syms[i]). retcode)
162+ cache. best = $ (i)
163+ cache. force_stop = true
164+ cache. retcode = $ (cache_syms[i]). retcode
165+ else
166+ cache. current = $ (i + 1 )
167+ end
168+ end
169+ return
170+ end
171+ end )
172+ end
173+
174+ push! (calls,
175+ quote
176+ if ! (1 ≤ cache. current ≤ length (cache. caches))
177+ minfu, idx = __findmin (first (cache. caches). internalnorm, cache. caches)
178+ cache. best = idx
179+ cache. retcode = cache. caches[cache. best]. retcode
180+ cache. force_stop = true
181+ return
182+ end
183+ end )
184+
185+ return Expr (:block , calls... )
186+ end
187+
120188for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
121189 algType = NonlinearSolvePolyAlgorithm{pType}
122190 @eval begin
0 commit comments