6565 force_stop:: Bool
6666 maxiters:: Int
6767 internalnorm
68+ u0
69+ u0_aliased
70+ alias_u0:: Bool
6871end
6972
7073function Base. show (
@@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
9194 @eval begin
9295 function SciMLBase. __init (
9396 prob:: $probType , alg:: $algType{N} , args... ; maxtime = nothing ,
94- maxiters = 1000 , internalnorm = DEFAULT_NORM, kwargs... ) where {N}
97+ maxiters = 1000 , internalnorm = DEFAULT_NORM,
98+ alias_u0 = false , verbose = true , kwargs... ) where {N}
99+ if (alias_u0 && ! ismutable (prob. u0))
100+ verbose && @warn " `alias_u0` has been set to `true`, but `u0` is \
101+ immutable (checked using `ArrayInterface.ismutable`)."
102+ alias_u0 = false # If immutable don't care about aliasing
103+ end
104+ u0 = prob. u0
105+ if alias_u0
106+ u0_aliased = copy (u0)
107+ else
108+ u0_aliased = u0 # Irrelevant
109+ end
110+ alias_u0 && (prob = remake (prob; u0 = u0_aliased))
95111 return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N, maxtime !== nothing} (
96112 map (
97- solver -> SciMLBase. __init (
98- prob, solver, args ... ; maxtime, internalnorm , kwargs... ),
113+ solver -> SciMLBase. __init (prob, solver, args ... ; maxtime,
114+ internalnorm, alias_u0, verbose , kwargs... ),
99115 alg. algs),
100116 alg,
101117 - 1 ,
@@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
106122 ReturnCode. Default,
107123 false ,
108124 maxiters,
109- internalnorm)
125+ internalnorm,
126+ u0,
127+ u0_aliased,
128+ alias_u0)
110129 end
111130 end
112131end
@@ -120,20 +139,30 @@ end
120139
121140 cache_syms = [gensym (" cache" ) for i in 1 : N]
122141 sol_syms = [gensym (" sol" ) for i in 1 : N]
142+ u_result_syms = [gensym (" u_result" ) for i in 1 : N]
123143 for i in 1 : N
124144 push! (calls,
125145 quote
126146 $ (cache_syms[i]) = cache. caches[$ (i)]
127147 if $ (i) == cache. current
148+ cache. alias_u0 && copyto! (cache. u0_aliased, cache. u0)
128149 $ (sol_syms[i]) = SciMLBase. solve! ($ (cache_syms[i]))
129150 if SciMLBase. successful_retcode ($ (sol_syms[i]))
130151 stats = $ (sol_syms[i]). stats
131- u = $ (sol_syms[i]). u
152+ if cache. alias_u0
153+ copyto! (cache. u0, $ (sol_syms[i]). u)
154+ $ (u_result_syms[i]) = cache. u0
155+ else
156+ $ (u_result_syms[i]) = $ (sol_syms[i]). u
157+ end
132158 fu = get_fu ($ (cache_syms[i]))
133159 return SciMLBase. build_solution (
134- $ (sol_syms[i]). prob, cache. alg, u, fu;
135- retcode = $ (sol_syms[i]). retcode, stats,
160+ $ (sol_syms[i]). prob, cache. alg, $ (u_result_syms[i]),
161+ fu; retcode = $ (sol_syms[i]). retcode, stats,
136162 original = $ (sol_syms[i]), trace = $ (sol_syms[i]). trace)
163+ elseif cache. alias_u0
164+ # For safety we need to maintain a copy of the solution
165+ $ (u_result_syms[i]) = copy ($ (sol_syms[i]). u)
137166 end
138167 cache. current = $ (i + 1 )
139168 end
@@ -144,14 +173,29 @@ end
144173 for (sym, resid) in zip (cache_syms, resids)
145174 push! (calls, :($ (resid) = @isdefined ($ (sym)) ? get_fu ($ (sym)) : nothing ))
146175 end
176+ push! (calls, quote
177+ fus = tuple ($ (Tuple (resids)... ))
178+ minfu, idx = __findmin (cache. internalnorm, fus)
179+ stats = __compile_stats (cache. caches[idx])
180+ end )
181+ for i in 1 : N
182+ push! (calls, quote
183+ if idx == $ (i)
184+ if cache. alias_u0
185+ u = $ (u_result_syms[i])
186+ else
187+ u = get_u (cache. caches[$ i])
188+ end
189+ end
190+ end )
191+ end
147192 push! (calls,
148193 quote
149- fus = tuple ($ (Tuple (resids)... ))
150- minfu, idx = __findmin (cache. internalnorm, fus)
151- stats = __compile_stats (cache. caches[idx])
152- u = get_u (cache. caches[idx])
153194 retcode = cache. caches[idx]. retcode
154-
195+ if cache. alias_u0
196+ copyto! (cache. u0, u)
197+ u = cache. u0
198+ end
155199 return SciMLBase. build_solution (cache. caches[idx]. prob, cache. alg, u, fus[idx];
156200 retcode, stats, cache. caches[idx]. trace)
157201 end )
@@ -200,22 +244,52 @@ end
200244for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
201245 algType = NonlinearSolvePolyAlgorithm{pType}
202246 @eval begin
203- @generated function SciMLBase. __solve (
204- prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
205- calls = [:(current = alg. start_index)]
247+ @generated function SciMLBase. __solve (prob:: $probType , alg:: $algType{N} , args... ;
248+ alias_u0 = false , verbose = true , kwargs... ) where {N}
206249 sol_syms = [gensym (" sol" ) for _ in 1 : N]
250+ prob_syms = [gensym (" prob" ) for _ in 1 : N]
251+ u_result_syms = [gensym (" u_result" ) for _ in 1 : N]
252+ calls = [quote
253+ current = alg. start_index
254+ if (alias_u0 && ! ismutable (prob. u0))
255+ verbose && @warn " `alias_u0` has been set to `true`, but `u0` is \
256+ immutable (checked using `ArrayInterface.ismutable`)."
257+ alias_u0 = false # If immutable don't care about aliasing
258+ end
259+ u0 = prob. u0
260+ if alias_u0
261+ u0_aliased = similar (u0)
262+ else
263+ u0_aliased = u0 # Irrelevant
264+ end
265+ end ]
207266 for i in 1 : N
208267 cur_sol = sol_syms[i]
209268 push! (calls,
210269 quote
211270 if current == $ i
212- $ (cur_sol) = SciMLBase. __solve (
213- prob, alg. algs[$ (i)], args... ; kwargs... )
271+ if alias_u0
272+ copyto! (u0_aliased, u0)
273+ $ (prob_syms[i]) = remake (prob; u0 = u0_aliased)
274+ else
275+ $ (prob_syms[i]) = prob
276+ end
277+ $ (cur_sol) = SciMLBase. __solve ($ (prob_syms[i]), alg. algs[$ (i)],
278+ args... ; alias_u0, verbose, kwargs... )
214279 if SciMLBase. successful_retcode ($ (cur_sol))
280+ if alias_u0
281+ copyto! (u0, $ (cur_sol). u)
282+ $ (u_result_syms[i]) = u0
283+ else
284+ $ (u_result_syms[i]) = $ (cur_sol). u
285+ end
215286 return SciMLBase. build_solution (
216- prob, alg, $ (cur_sol) . u , $ (cur_sol). resid;
287+ prob, alg, $ (u_result_syms[i]) , $ (cur_sol). resid;
217288 $ (cur_sol). retcode, $ (cur_sol). stats,
218289 original = $ (cur_sol), trace = $ (cur_sol). trace)
290+ elseif alias_u0
291+ # For safety we need to maintain a copy of the solution
292+ $ (u_result_syms[i]) = copy ($ (cur_sol). u)
219293 end
220294 current = $ (i + 1 )
221295 end
@@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
236310 push! (calls,
237311 quote
238312 if idx == $ i
239- return SciMLBase. build_solution (prob, alg, $ (sol_syms[i]). u,
240- $ (sol_syms[i]). resid; $ (sol_syms[i]). retcode,
241- $ (sol_syms[i]). stats, $ (sol_syms[i]). trace)
313+ if alias_u0
314+ copyto! (u0, $ (u_result_syms[i]))
315+ $ (u_result_syms[i]) = u0
316+ else
317+ $ (u_result_syms[i]) = $ (sol_syms[i]). u
318+ end
319+ return SciMLBase. build_solution (
320+ prob, alg, $ (u_result_syms[i]), $ (sol_syms[i]). resid;
321+ $ (sol_syms[i]). retcode, $ (sol_syms[i]). stats,
322+ $ (sol_syms[i]). trace, original = $ (sol_syms[i]))
242323 end
243324 end )
244325 end
0 commit comments