|
1 | 1 | # Optimized tolerance checking that avoids allocations for scalar abstol |
2 | | -@inline function check_dae_tolerance(internalnorm, err, abstol, t) |
| 2 | +@inline function check_dae_tolerance(integrator, err, abstol, t) |
3 | 3 | if abstol isa Number |
4 | | - return internalnorm(err, t) / abstol <= 1 |
| 4 | + return integrator.opts.internalnorm(err, t) / abstol <= 1 |
5 | 5 | else |
6 | | - return internalnorm(err ./ abstol, t) <= 1 |
| 6 | + return integrator.opts.internalnorm(err ./ abstol, t) <= 1 |
7 | 7 | end |
8 | 8 | end |
9 | 9 |
|
@@ -66,8 +66,13 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
66 | 66 | f(tmp, u0, p, t) |
67 | 67 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
68 | 68 |
|
69 | | - check_dae_tolerance(integrator.opts.internalnorm, tmp, integrator.opts.abstol, t) && |
70 | | - return |
| 69 | + # Zero-allocation tolerance check reusing tmp |
| 70 | + if integrator.opts.abstol isa Number |
| 71 | + integrator.opts.internalnorm(tmp, t) / integrator.opts.abstol <= 1 && return |
| 72 | + else |
| 73 | + @. tmp = tmp / integrator.opts.abstol |
| 74 | + integrator.opts.internalnorm(tmp, t) <= 1 && return |
| 75 | + end |
71 | 76 |
|
72 | 77 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
73 | 78 | # backward Euler |
@@ -181,8 +186,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation |
181 | 186 | du = f(u0, p, t) |
182 | 187 | resid = _vec(du)[algebraic_eqs] |
183 | 188 |
|
184 | | - check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) && |
185 | | - return |
| 189 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
186 | 190 |
|
187 | 191 | if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve) |
188 | 192 | # backward Euler |
@@ -250,8 +254,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
250 | 254 | dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction |
251 | 255 |
|
252 | 256 | f(resid, integrator.du, u0, p, t) |
253 | | - check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) && |
254 | | - return |
| 257 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
255 | 258 |
|
256 | 259 | # _du and _u should be non-dual since NonlinearSolve does not differentiate the solver |
257 | 260 | # These non-dual values are thus used to make the caches |
@@ -332,8 +335,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
332 | 335 | nlequation = (u, _) -> nlequation_oop(u) |
333 | 336 |
|
334 | 337 | resid = f(integrator.du, u0, p, t) |
335 | | - check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) && |
336 | | - return |
| 338 | + check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return |
337 | 339 |
|
338 | 340 | jac = if isnothing(f.jac) |
339 | 341 | f.jac |
@@ -398,7 +400,13 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
398 | 400 |
|
399 | 401 | tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) |
400 | 402 |
|
401 | | - check_dae_tolerance(integrator.opts.internalnorm, tmp, alg.abstol, t) && return |
| 403 | + # Zero-allocation tolerance check reusing tmp |
| 404 | + if alg.abstol isa Number |
| 405 | + integrator.opts.internalnorm(tmp, t) / alg.abstol <= 1 && return |
| 406 | + else |
| 407 | + @. tmp = tmp / alg.abstol |
| 408 | + integrator.opts.internalnorm(tmp, t) <= 1 && return |
| 409 | + end |
402 | 410 | alg_u = @view u[algebraic_vars] |
403 | 411 |
|
404 | 412 | # These non-dual values are thus used to make the caches |
@@ -477,7 +485,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, |
477 | 485 | du = f(u0, p, t) |
478 | 486 | resid = _vec(du)[algebraic_eqs] |
479 | 487 |
|
480 | | - check_dae_tolerance(integrator.opts.internalnorm, resid, alg.abstol, t) && return |
| 488 | + check_dae_tolerance(integrator, resid, alg.abstol, t) && return |
481 | 489 |
|
482 | 490 | isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff |
483 | 491 | if isAD |
@@ -556,7 +564,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
556 | 564 | normtmp = get_tmp_cache(integrator)[1] |
557 | 565 | f(normtmp, du, u, p, t) |
558 | 566 |
|
559 | | - if check_dae_tolerance(integrator.opts.internalnorm, normtmp, alg.abstol, t) |
| 567 | + if check_dae_tolerance(integrator, normtmp, alg.abstol, t) |
560 | 568 | return |
561 | 569 | elseif differential_vars === nothing |
562 | 570 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
@@ -617,7 +625,7 @@ function _initialize_dae!(integrator, prob::DAEProblem, |
617 | 625 | @unpack p, t, f = integrator |
618 | 626 | differential_vars = prob.differential_vars |
619 | 627 |
|
620 | | - if check_dae_tolerance(integrator.opts.internalnorm, f(integrator.du, integrator.u, p, t), alg.abstol, t) |
| 628 | + if check_dae_tolerance(integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t) |
621 | 629 | return |
622 | 630 | elseif differential_vars === nothing |
623 | 631 | error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.") |
|
0 commit comments