@@ -44,26 +44,35 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
4444 end
4545end
4646
47- @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} < :
48- AbstractNonlinearSolveCache{iip, false }
47+ @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit } < :
48+ AbstractNonlinearSolveCache{iip, timeit }
4949 caches
5050 alg
51+ best:: Int
5152 current:: Int
53+ nsteps:: Int
54+ total_time:: Float64
55+ maxtime
56+ retcode:: ReturnCode.T
57+ force_stop:: Bool
5258end
5359
5460function reinit_cache! (cache:: NonlinearSolvePolyAlgorithmCache , args... ; kwargs... )
5561 foreach (c -> reinit_cache! (c, args... ; kwargs... ), cache. caches)
5662 cache. current = 1
63+ cache. nsteps = 0
64+ cache. total_time = 0.0
5765end
5866
5967for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
6068 algType = NonlinearSolvePolyAlgorithm{pType}
6169 @eval begin
62- function SciMLBase. __init (
63- prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
64- return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N} (
65- map (solver -> SciMLBase. __init (prob, solver, args... ; kwargs... ), alg. algs),
66- alg, 1 )
70+ function SciMLBase. __init (prob:: $probType , alg:: $algType{N} , args... ;
71+ maxtime = nothing , kwargs... ) where {N}
72+ return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N, maxtime !== nothing} (
73+ map (solver -> SciMLBase. __init (prob, solver, args... ; maxtime, kwargs... ),
74+ alg. algs), alg, - 1 , 1 , 0 , 0.0 , maxtime,
75+ ReturnCode. Default, false )
6776 end
6877 end
6978end
8796 stats = $ (sol_syms[i]). stats
8897 u = $ (sol_syms[i]). u
8998 fu = get_fu ($ (cache_syms[i]))
90- return SciMLBase. build_solution (
91- $ (sol_syms[i]). prob, cache. alg, u, fu;
92- retcode = ReturnCode. Success, stats,
99+ return SciMLBase. build_solution ($ (sol_syms[i]). prob, cache. alg, u,
100+ fu; retcode = $ (sol_syms[i]). retcode, stats,
93101 original = $ (sol_syms[i]), trace = $ (sol_syms[i]). trace)
94102 end
95103 cache. current = $ (i + 1 )
@@ -103,12 +111,11 @@ end
103111 end
104112 push! (calls,
105113 quote
106- retcode = ReturnCode. MaxIters
107-
108114 fus = tuple ($ (Tuple (resids)... ))
109115 minfu, idx = __findmin (cache. caches[1 ]. internalnorm, fus)
110116 stats = cache. caches[idx]. stats
111- u = cache. caches[idx]. u
117+ u = get_u (cache. caches[idx])
118+ retcode = cache. caches[idx]. retcode
112119
113120 return SciMLBase. build_solution (cache. caches[idx]. prob, cache. alg, u, fus[idx];
114121 retcode, stats, cache. caches[idx]. trace)
117124 return Expr (:block , calls... )
118125end
119126
127+ @generated function __step! (
128+ cache:: NonlinearSolvePolyAlgorithmCache{iip, N} , args... ; kwargs... ) where {iip, N}
129+ calls = []
130+ cache_syms = [gensym (" cache" ) for i in 1 : N]
131+ for i in 1 : N
132+ push! (calls,
133+ quote
134+ $ (cache_syms[i]) = cache. caches[$ (i)]
135+ if $ (i) == cache. current
136+ __step! ($ (cache_syms[i]), args... ; kwargs... )
137+ if ! not_terminated ($ (cache_syms[i]))
138+ if SciMLBase. successful_retcode ($ (cache_syms[i]). retcode)
139+ cache. best = $ (i)
140+ cache. force_stop = true
141+ cache. retcode = $ (cache_syms[i]). retcode
142+ else
143+ cache. current = $ (i + 1 )
144+ end
145+ end
146+ return
147+ end
148+ end )
149+ end
150+
151+ push! (calls,
152+ quote
153+ if ! (1 ≤ cache. current ≤ length (cache. caches))
154+ minfu, idx = __findmin (first (cache. caches). internalnorm, cache. caches)
155+ cache. best = idx
156+ cache. retcode = cache. caches[cache. best]. retcode
157+ cache. force_stop = true
158+ return
159+ end
160+ end
161+ )
162+
163+ return Expr (:block , calls... )
164+ end
165+
120166for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
121167 algType = NonlinearSolvePolyAlgorithm{pType}
122168 @eval begin
0 commit comments