Skip to content

Commit 7b6701e

Browse files
Further optimize DAE tolerance checking using existing temp arrays
Eliminate allocations in the most performance-critical cases by reusing existing temporary arrays where available:

- In-place division: `@. tmp = tmp / abstol` (zero allocation)
- Reuses `tmp` arrays already allocated in the same scope
- Falls back to the helper function for cases without available temps

Key optimizations:

1. ShampineCollocationInit: Reuse `tmp` after algebraic restructure
2. BrownFullBasicInit: Reuse `tmp` after algebraic restructure
3. Other cases: Use helper function with scalar optimization

This provides the best balance of performance and code maintainability while completely eliminating allocations where most beneficial.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent b0eec3f commit 7b6701e

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Optimized tolerance checking that avoids allocations for scalar abstol
2-
@inline function check_dae_tolerance(internalnorm, err, abstol, t)
2+
@inline function check_dae_tolerance(integrator, err, abstol, t)
33
if abstol isa Number
4-
return internalnorm(err, t) / abstol <= 1
4+
return integrator.opts.internalnorm(err, t) / abstol <= 1
55
else
6-
return internalnorm(err ./ abstol, t) <= 1
6+
return integrator.opts.internalnorm(err ./ abstol, t) <= 1
77
end
88
end
99

@@ -66,8 +66,13 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
6666
f(tmp, u0, p, t)
6767
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
6868

69-
check_dae_tolerance(integrator.opts.internalnorm, tmp, integrator.opts.abstol, t) &&
70-
return
69+
# Zero-allocation tolerance check reusing tmp
70+
if integrator.opts.abstol isa Number
71+
integrator.opts.internalnorm(tmp, t) / integrator.opts.abstol <= 1 && return
72+
else
73+
@. tmp = tmp / integrator.opts.abstol
74+
integrator.opts.internalnorm(tmp, t) <= 1 && return
75+
end
7176

7277
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
7378
# backward Euler
@@ -181,8 +186,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
181186
du = f(u0, p, t)
182187
resid = _vec(du)[algebraic_eqs]
183188

184-
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
185-
return
189+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
186190

187191
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
188192
# backward Euler
@@ -250,8 +254,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
250254
dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction
251255

252256
f(resid, integrator.du, u0, p, t)
253-
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
254-
return
257+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
255258

256259
# _du and _u should be non-dual since NonlinearSolve does not differentiate the solver
257260
# These non-dual values are thus used to make the caches
@@ -332,8 +335,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
332335
nlequation = (u, _) -> nlequation_oop(u)
333336

334337
resid = f(integrator.du, u0, p, t)
335-
check_dae_tolerance(integrator.opts.internalnorm, resid, integrator.opts.abstol, t) &&
336-
return
338+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
337339

338340
jac = if isnothing(f.jac)
339341
f.jac
@@ -398,7 +400,13 @@ function _initialize_dae!(integrator, prob::ODEProblem,
398400

399401
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
400402

401-
check_dae_tolerance(integrator.opts.internalnorm, tmp, alg.abstol, t) && return
403+
# Zero-allocation tolerance check reusing tmp
404+
if alg.abstol isa Number
405+
integrator.opts.internalnorm(tmp, t) / alg.abstol <= 1 && return
406+
else
407+
@. tmp = tmp / alg.abstol
408+
integrator.opts.internalnorm(tmp, t) <= 1 && return
409+
end
402410
alg_u = @view u[algebraic_vars]
403411

404412
# These non-dual values are thus used to make the caches
@@ -477,7 +485,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
477485
du = f(u0, p, t)
478486
resid = _vec(du)[algebraic_eqs]
479487

480-
check_dae_tolerance(integrator.opts.internalnorm, resid, alg.abstol, t) && return
488+
check_dae_tolerance(integrator, resid, alg.abstol, t) && return
481489

482490
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
483491
if isAD
@@ -556,7 +564,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
556564
normtmp = get_tmp_cache(integrator)[1]
557565
f(normtmp, du, u, p, t)
558566

559-
if check_dae_tolerance(integrator.opts.internalnorm, normtmp, alg.abstol, t)
567+
if check_dae_tolerance(integrator, normtmp, alg.abstol, t)
560568
return
561569
elseif differential_vars === nothing
562570
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -617,7 +625,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
617625
@unpack p, t, f = integrator
618626
differential_vars = prob.differential_vars
619627

620-
if check_dae_tolerance(integrator.opts.internalnorm, f(integrator.du, integrator.u, p, t), alg.abstol, t)
628+
if check_dae_tolerance(integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t)
621629
return
622630
elseif differential_vars === nothing
623631
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")

0 commit comments

Comments
 (0)