|
| 1 | +# Optimized tolerance checking that avoids allocations |
| 2 | +@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{true}) |
| 3 | + if abstol isa Number |
| 4 | + return integrator.opts.internalnorm(err, t) / abstol <= 1 |
| 5 | + else |
| 6 | + @. err = err / abstol # Safe for in-place functions |
| 7 | + return integrator.opts.internalnorm(err, t) <= 1 |
| 8 | + end |
| 9 | +end |
| 10 | + |
| 11 | +@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{false}) |
| 12 | + if abstol isa Number |
| 13 | + return integrator.opts.internalnorm(err, t) / abstol <= 1 |
| 14 | + else |
| 15 | + return integrator.opts.internalnorm(err ./ abstol, t) <= 1 # Allocates for out-of-place |
| 16 | + end |
| 17 | +end |
| 18 | + |
1 | 19 | function default_nlsolve( |
2 | 20 | ::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false) |
3 | 21 | FastShortcutNonlinearPolyalg(; |
@@ -57,22 +75,24 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
57 | 75 | f(tmp, u0, p, t) |
58 | 76 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
59 | 77 |
|
60 | | - integrator.opts.internalnorm(tmp, t) <= integrator.opts.abstol && return |
| 78 | + check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return |
61 | 79 |
|
62 | 80 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
63 | 81 | # backward Euler |
64 | 82 | nlsolver = integrator.cache.nlsolver |
65 | | - oldγ, oldc, oldmethod, olddt = nlsolver.γ, nlsolver.c, nlsolver.method, |
| 83 | + oldγ, oldc, oldmethod, |
| 84 | + olddt = nlsolver.γ, nlsolver.c, nlsolver.method, |
66 | 85 | integrator.dt |
67 | 86 | nlsolver.tmp .= integrator.uprev |
68 | 87 | nlsolver.γ, nlsolver.c = 1, 1 |
69 | 88 | nlsolver.method = DIRK |
70 | 89 | integrator.dt = dt |
71 | 90 | z = nlsolve!(nlsolver, integrator, integrator.cache) |
72 | | - nlsolver.γ, nlsolver.c, nlsolver.method, integrator.dt = oldγ, oldc, oldmethod, |
| 91 | + nlsolver.γ, nlsolver.c, nlsolver.method, |
| 92 | + integrator.dt = oldγ, oldc, oldmethod, |
73 | 93 | olddt |
74 | 94 | failed = nlsolvefail(nlsolver) |
75 | | - @.. broadcast=false integrator.u=integrator.uprev + z |
| 95 | + @.. broadcast=false integrator.u=integrator.uprev+z |
76 | 96 | else |
77 | 97 |
|
78 | 98 | # _u0 should be non-dual since NonlinearSolve does not differentiate the solver |
@@ -169,22 +189,24 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
169 | 189 | du = f(u0, p, t) |
170 | 190 | resid = _vec(du)[algebraic_eqs] |
171 | 191 |
|
172 | | - integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return |
| 192 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
173 | 193 |
|
174 | 194 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
175 | 195 | # backward Euler |
176 | 196 | nlsolver = integrator.cache.nlsolver |
177 | | - oldγ, oldc, oldmethod, olddt = nlsolver.γ, nlsolver.c, nlsolver.method, |
| 197 | + oldγ, oldc, oldmethod, |
| 198 | + olddt = nlsolver.γ, nlsolver.c, nlsolver.method, |
178 | 199 | integrator.dt |
179 | 200 | nlsolver.tmp .= integrator.uprev |
180 | 201 | nlsolver.γ, nlsolver.c = 1, 1 |
181 | 202 | nlsolver.method = DIRK |
182 | 203 | integrator.dt = dt |
183 | 204 | z = nlsolve!(nlsolver, integrator, integrator.cache) |
184 | | - nlsolver.γ, nlsolver.c, nlsolver.method, integrator.dt = oldγ, oldc, oldmethod, |
| 205 | + nlsolver.γ, nlsolver.c, nlsolver.method, |
| 206 | + integrator.dt = oldγ, oldc, oldmethod, |
185 | 207 | olddt |
186 | 208 | failed = nlsolvefail(nlsolver) |
187 | | - @.. broadcast=false integrator.u=integrator.uprev + z |
| 209 | + @.. broadcast=false integrator.u=integrator.uprev+z |
188 | 210 | else |
189 | 211 | nlequation_oop = @closure (u, _) -> begin |
190 | 212 | update_coefficients!(M, u, p, t) |
@@ -235,7 +257,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
235 | 257 | dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction |
236 | 258 |
|
237 | 259 | f(resid, integrator.du, u0, p, t) |
238 | | - integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return |
| 260 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
239 | 261 |
|
240 | 262 | # _du and _u should be non-dual since NonlinearSolve does not differentiate the solver |
241 | 263 | # These non-dual values are thus used to make the caches |
@@ -316,7 +338,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
316 | 338 | nlequation = (u, _) -> nlequation_oop(u) |
317 | 339 |
|
318 | 340 | resid = f(integrator.du, u0, p, t) |
319 | | - integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return |
| 341 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
320 | 342 |
|
321 | 343 | jac = if isnothing(f.jac) |
322 | 344 | f.jac |
@@ -381,7 +403,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
381 | 403 |
|
382 | 404 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
383 | 405 |
|
384 | | - integrator.opts.internalnorm(tmp, t) <= alg.abstol && return |
| 406 | + check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return |
385 | 407 | alg_u = @view u[algebraic_vars] |
386 | 408 |
|
387 | 409 | # These non-dual values are thus used to make the caches |
@@ -460,7 +482,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
460 | 482 | du = f(u0, p, t) |
461 | 483 | resid = _vec(du)[algebraic_eqs] |
462 | 484 |
|
463 | | - integrator.opts.internalnorm(resid, t) <= alg.abstol && return |
| 485 | + check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return |
464 | 486 |
|
465 | 487 | isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff |
466 | 488 | if isAD |
@@ -539,7 +561,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
539 | 561 | normtmp = get_tmp_cache(integrator)[1] |
540 | 562 | f(normtmp, du, u, p, t) |
541 | 563 |
|
542 | | - if integrator.opts.internalnorm(normtmp, t) <= alg.abstol |
| 564 | + if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace) |
543 | 565 | return |
544 | 566 | elseif differential_vars === nothing |
545 | 567 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
@@ -600,7 +622,8 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
600 | 622 | @unpack p, t, f = integrator |
601 | 623 | differential_vars = prob.differential_vars |
602 | 624 |
|
603 | | - if integrator.opts.internalnorm(f(integrator.du, integrator.u, p, t), t) <= alg.abstol |
| 625 | + if check_dae_tolerance( |
| 626 | + integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace) |
604 | 627 | return |
605 | 628 | elseif differential_vars === nothing |
606 | 629 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
|
0 commit comments