|
1 | 1 | # Optimized tolerance checking that avoids allocations |
2 | | -@inline function check_dae_tolerance(integrator, err, abstol, t) |
| 2 | +@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{true}) |
3 | 3 | if abstol isa Number |
4 | 4 | return integrator.opts.internalnorm(err, t) / abstol <= 1 |
5 | 5 | else |
6 | | - @. err = err / abstol |
| 6 | + @. err = err / abstol # Safe for in-place functions |
7 | 7 | return integrator.opts.internalnorm(err, t) <= 1 |
8 | 8 | end |
9 | 9 | end |
10 | 10 |
|
| 11 | +@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{false}) |
| 12 | + if abstol isa Number |
| 13 | + return integrator.opts.internalnorm(err, t) / abstol <= 1 |
| 14 | + else |
| 15 | + return integrator.opts.internalnorm(err ./ abstol, t) <= 1 # Allocates for out-of-place |
| 16 | + end |
| 17 | +end |
| 18 | + |
11 | 19 | function default_nlsolve( |
12 | 20 | ::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false) |
13 | 21 | FastShortcutNonlinearPolyalg(; |
@@ -67,7 +75,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
67 | 75 | f(tmp, u0, p, t) |
68 | 76 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
69 | 77 |
|
70 | | - check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t) && return |
| 78 | + check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return |
71 | 79 |
|
72 | 80 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
73 | 81 | # backward Euler |
@@ -181,7 +189,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
181 | 189 | du = f(u0, p, t) |
182 | 190 | resid = _vec(du)[algebraic_eqs] |
183 | 191 |
|
184 | | - check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
| 192 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
185 | 193 |
|
186 | 194 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
187 | 195 | # backward Euler |
@@ -249,7 +257,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
249 | 257 | dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction |
250 | 258 |
|
251 | 259 | f(resid, integrator.du, u0, p, t) |
252 | | - check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
| 260 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
253 | 261 |
|
254 | 262 | # _du and _u should be non-dual since NonlinearSolve does not differentiate the solver |
255 | 263 | # These non-dual values are thus used to make the caches |
@@ -330,7 +338,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
330 | 338 | nlequation = (u, _) -> nlequation_oop(u) |
331 | 339 |
|
332 | 340 | resid = f(integrator.du, u0, p, t) |
333 | | - check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
| 341 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return |
334 | 342 |
|
335 | 343 | jac = if isnothing(f.jac) |
336 | 344 | f.jac |
@@ -395,7 +403,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
395 | 403 |
|
396 | 404 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
397 | 405 |
|
398 | | - check_dae_tolerance(integrator, tmp, alg.abstol, t) && return |
| 406 | + check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return |
399 | 407 | alg_u = @view u[algebraic_vars] |
400 | 408 |
|
401 | 409 | # These non-dual values are thus used to make the caches |
@@ -474,7 +482,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
474 | 482 | du = f(u0, p, t) |
475 | 483 | resid = _vec(du)[algebraic_eqs] |
476 | 484 |
|
477 | | - check_dae_tolerance(integrator, resid, alg.abstol, t) && return |
| 485 | + check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return |
478 | 486 |
|
479 | 487 | isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff |
480 | 488 | if isAD |
@@ -553,7 +561,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
553 | 561 | normtmp = get_tmp_cache(integrator)[1] |
554 | 562 | f(normtmp, du, u, p, t) |
555 | 563 |
|
556 | | - if check_dae_tolerance(integrator, normtmp, alg.abstol, t) |
| 564 | + if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace) |
557 | 565 | return |
558 | 566 | elseif differential_vars === nothing |
559 | 567 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
@@ -614,7 +622,8 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
614 | 622 | @unpack p, t, f = integrator |
615 | 623 | differential_vars = prob.differential_vars |
616 | 624 |
|
617 | | - if check_dae_tolerance(integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t) |
| 625 | + if check_dae_tolerance( |
| 626 | + integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace) |
618 | 627 | return |
619 | 628 | elseif differential_vars === nothing |
620 | 629 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
|
0 commit comments