Skip to content

Commit ddf2c2d

Browse files
Merge pull request #2868 from ChrisRackauckas-Claude/fix-vector-abstol-dae-2820
Fix DAE mass matrix initialization with vector abstol (fixes #2820)
2 parents 9b97006 + 89fc779 commit ddf2c2d

File tree

1 file changed

+37
-14
lines changed

1 file changed

+37
-14
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
# Optimized tolerance checking that avoids allocations
2+
@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{true})
3+
if abstol isa Number
4+
return integrator.opts.internalnorm(err, t) / abstol <= 1
5+
else
6+
@. err = err / abstol # Safe for in-place functions
7+
return integrator.opts.internalnorm(err, t) <= 1
8+
end
9+
end
10+
11+
@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{false})
12+
if abstol isa Number
13+
return integrator.opts.internalnorm(err, t) / abstol <= 1
14+
else
15+
return integrator.opts.internalnorm(err ./ abstol, t) <= 1 # Allocates for out-of-place
16+
end
17+
end
18+
119
function default_nlsolve(
220
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
321
FastShortcutNonlinearPolyalg(;
@@ -57,22 +75,24 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
5775
f(tmp, u0, p, t)
5876
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
5977

60-
integrator.opts.internalnorm(tmp, t) <= integrator.opts.abstol && return
78+
check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return
6179

6280
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
6381
# backward Euler
6482
nlsolver = integrator.cache.nlsolver
65-
oldγ, oldc, oldmethod, olddt = nlsolver.γ, nlsolver.c, nlsolver.method,
83+
oldγ, oldc, oldmethod,
84+
olddt = nlsolver.γ, nlsolver.c, nlsolver.method,
6685
integrator.dt
6786
nlsolver.tmp .= integrator.uprev
6887
nlsolver.γ, nlsolver.c = 1, 1
6988
nlsolver.method = DIRK
7089
integrator.dt = dt
7190
z = nlsolve!(nlsolver, integrator, integrator.cache)
72-
nlsolver.γ, nlsolver.c, nlsolver.method, integrator.dt = oldγ, oldc, oldmethod,
91+
nlsolver.γ, nlsolver.c, nlsolver.method,
92+
integrator.dt = oldγ, oldc, oldmethod,
7393
olddt
7494
failed = nlsolvefail(nlsolver)
75-
@.. broadcast=false integrator.u=integrator.uprev + z
95+
@.. broadcast=false integrator.u=integrator.uprev+z
7696
else
7797

7898
# _u0 should be non-dual since NonlinearSolve does not differentiate the solver
@@ -169,22 +189,24 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
169189
du = f(u0, p, t)
170190
resid = _vec(du)[algebraic_eqs]
171191

172-
integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return
192+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
173193

174194
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
175195
# backward Euler
176196
nlsolver = integrator.cache.nlsolver
177-
oldγ, oldc, oldmethod, olddt = nlsolver.γ, nlsolver.c, nlsolver.method,
197+
oldγ, oldc, oldmethod,
198+
olddt = nlsolver.γ, nlsolver.c, nlsolver.method,
178199
integrator.dt
179200
nlsolver.tmp .= integrator.uprev
180201
nlsolver.γ, nlsolver.c = 1, 1
181202
nlsolver.method = DIRK
182203
integrator.dt = dt
183204
z = nlsolve!(nlsolver, integrator, integrator.cache)
184-
nlsolver.γ, nlsolver.c, nlsolver.method, integrator.dt = oldγ, oldc, oldmethod,
205+
nlsolver.γ, nlsolver.c, nlsolver.method,
206+
integrator.dt = oldγ, oldc, oldmethod,
185207
olddt
186208
failed = nlsolvefail(nlsolver)
187-
@.. broadcast=false integrator.u=integrator.uprev + z
209+
@.. broadcast=false integrator.u=integrator.uprev+z
188210
else
189211
nlequation_oop = @closure (u, _) -> begin
190212
update_coefficients!(M, u, p, t)
@@ -235,7 +257,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
235257
dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction
236258

237259
f(resid, integrator.du, u0, p, t)
238-
integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return
260+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
239261

240262
# _du and _u should be non-dual since NonlinearSolve does not differentiate the solver
241263
# These non-dual values are thus used to make the caches
@@ -316,7 +338,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
316338
nlequation = (u, _) -> nlequation_oop(u)
317339

318340
resid = f(integrator.du, u0, p, t)
319-
integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return
341+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
320342

321343
jac = if isnothing(f.jac)
322344
f.jac
@@ -381,7 +403,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
381403

382404
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
383405

384-
integrator.opts.internalnorm(tmp, t) <= alg.abstol && return
406+
check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return
385407
alg_u = @view u[algebraic_vars]
386408

387409
# These non-dual values are thus used to make the caches
@@ -460,7 +482,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
460482
du = f(u0, p, t)
461483
resid = _vec(du)[algebraic_eqs]
462484

463-
integrator.opts.internalnorm(resid, t) <= alg.abstol && return
485+
check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return
464486

465487
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
466488
if isAD
@@ -539,7 +561,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
539561
normtmp = get_tmp_cache(integrator)[1]
540562
f(normtmp, du, u, p, t)
541563

542-
if integrator.opts.internalnorm(normtmp, t) <= alg.abstol
564+
if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace)
543565
return
544566
elseif differential_vars === nothing
545567
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -600,7 +622,8 @@ function _initialize_dae!(integrator, prob::DAEProblem,
600622
@unpack p, t, f = integrator
601623
differential_vars = prob.differential_vars
602624

603-
if integrator.opts.internalnorm(f(integrator.du, integrator.u, p, t), t) <= alg.abstol
625+
if check_dae_tolerance(
626+
integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace)
604627
return
605628
elseif differential_vars === nothing
606629
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")

0 commit comments

Comments
 (0)