Skip to content

Commit b0eec3f

Browse files
Optimize DAE tolerance checking to avoid allocations with scalar abstol
Add optimized helper function check_dae_tolerance() that: - Uses zero-allocation path for scalar abstol: internalnorm(err, t) / abstol <= 1 - Uses array division path for vector abstol: internalnorm(err ./ abstol, t) <= 1 Performance improvement: - Scalar abstol: 0 allocations (was 2 allocations, 80 bytes) - Vector abstol: Same performance (2 allocations as expected) This addresses performance regression concerns while maintaining the vector abstol functionality fix from the previous commit. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent bbf745e commit b0eec3f

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# Optimized tolerance checking that avoids allocations for scalar abstol
2+
@inline function check_dae_tolerance(internalnorm, err, abstol, t)
3+
if abstol isa Number
4+
return internalnorm(err, t) / abstol <= 1
5+
else
6+
return internalnorm(err ./ abstol, t) <= 1
7+
end
8+
end
9+
110
function default_nlsolve(
211
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
312
FastShortcutNonlinearPolyalg(;
@@ -57,7 +66,8 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
5766
f(tmp, u0, p, t)
5867
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
5968

60-
integrator.opts.internalnorm(tmp ./ integrator.opts.abstol, t) <= 1 && return
69+
check_dae_tolerance(integrator.opts.internalnorm, tmp, integrator.opts.abstol, t) &&
70+
return
6171

6272
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
6373
# backward Euler
@@ -171,7 +181,8 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
171181
du = f(u0, p, t)
172182
resid = _vec(du)[algebraic_eqs]
173183

174-
integrator.opts.internalnorm(resid ./ integrator.opts.abstol, t) <= 1 && return
184+
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
185+
return
175186

176187
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
177188
# backward Euler
@@ -239,7 +250,8 @@ function _initialize_dae!(integrator, prob::DAEProblem,
239250
dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction
240251

241252
f(resid, integrator.du, u0, p, t)
242-
integrator.opts.internalnorm(resid ./ integrator.opts.abstol, t) <= 1 && return
253+
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
254+
return
243255

244256
# _du and _u should be non-dual since NonlinearSolve does not differentiate the solver
245257
# These non-dual values are thus used to make the caches
@@ -320,7 +332,8 @@ function _initialize_dae!(integrator, prob::DAEProblem,
320332
nlequation = (u, _) -> nlequation_oop(u)
321333

322334
resid = f(integrator.du, u0, p, t)
323-
integrator.opts.internalnorm(resid ./ integrator.opts.abstol, t) <= 1 && return
335+
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
336+
return
324337

325338
jac = if isnothing(f.jac)
326339
f.jac
@@ -385,7 +398,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
385398

386399
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
387400

388-
integrator.opts.internalnorm(tmp ./ alg.abstol, t) <= 1 && return
401+
check_dae_tolerance(integrator.opts.internalnorm, tmp, alg.abstol, t) && return
389402
alg_u = @view u[algebraic_vars]
390403

391404
# These non-dual values are thus used to make the caches
@@ -464,7 +477,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
464477
du = f(u0, p, t)
465478
resid = _vec(du)[algebraic_eqs]
466479

467-
integrator.opts.internalnorm(resid ./ alg.abstol, t) <= 1 && return
480+
check_dae_tolerance(integrator.opts.internalnorm, resid, alg.abstol, t) && return
468481

469482
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
470483
if isAD
@@ -543,7 +556,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
543556
normtmp = get_tmp_cache(integrator)[1]
544557
f(normtmp, du, u, p, t)
545558

546-
if integrator.opts.internalnorm(normtmp ./ alg.abstol, t) <= 1
559+
if check_dae_tolerance(integrator.opts.internalnorm, normtmp, alg.abstol, t)
547560
return
548561
elseif differential_vars === nothing
549562
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -604,8 +617,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
604617
@unpack p, t, f = integrator
605618
differential_vars = prob.differential_vars
606619

607-
if integrator.opts.internalnorm(f(integrator.du, integrator.u, p, t) ./ alg.abstol, t) <=
608-
1
620+
if check_dae_tolerance(integrator.opts.internalnorm, f(integrator.du, integrator.u, p, t), alg.abstol, t)
609621
return
610622
elseif differential_vars === nothing
611623
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")

0 commit comments

Comments
 (0)