Skip to content

Commit 89fc779

Browse files
Add isinplace parameter to handle static arrays and non-mutable types safely
Critical fix: dispatch on isinplace to avoid mutating immutable types. For in-place functions (Val{true}): - @. err = err / abstol # Safe to mutate, zero allocation For out-of-place functions (Val{false}): - err ./ abstol # Must allocate, but handles StaticArrays correctly This prevents errors with StaticArrays and other immutable types while maintaining zero allocations for in-place cases. Tested with: - Regular mutable arrays (in-place): ✅ Zero allocation - StaticArrays (out-of-place): ✅ Works correctly with allocation - Scalar abstol: ✅ Zero allocation for both cases 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 6e2d92b commit 89fc779

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
# Optimized tolerance checking that avoids allocations
2-
@inline function check_dae_tolerance(integrator, err, abstol, t)
2+
@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{true})
33
if abstol isa Number
44
return integrator.opts.internalnorm(err, t) / abstol <= 1
55
else
6-
@. err = err / abstol
6+
@. err = err / abstol # Safe for in-place functions
77
return integrator.opts.internalnorm(err, t) <= 1
88
end
99
end
1010

11+
@inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{false})
12+
if abstol isa Number
13+
return integrator.opts.internalnorm(err, t) / abstol <= 1
14+
else
15+
return integrator.opts.internalnorm(err ./ abstol, t) <= 1 # Allocates for out-of-place
16+
end
17+
end
18+
1119
function default_nlsolve(
1220
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
1321
FastShortcutNonlinearPolyalg(;
@@ -67,7 +75,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
6775
f(tmp, u0, p, t)
6876
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
6977

70-
check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t) && return
78+
check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return
7179

7280
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
7381
# backward Euler
@@ -181,7 +189,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
181189
du = f(u0, p, t)
182190
resid = _vec(du)[algebraic_eqs]
183191

184-
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
192+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
185193

186194
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
187195
# backward Euler
@@ -249,7 +257,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
249257
dt = t != 0 ? min(t / 1000, dtmax / 10) : dtmax / 10 # Haven't implemented norm reduction
250258

251259
f(resid, integrator.du, u0, p, t)
252-
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
260+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
253261

254262
# _du and _u should be non-dual since NonlinearSolve does not differentiate the solver
255263
# These non-dual values are thus used to make the caches
@@ -330,7 +338,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
330338
nlequation = (u, _) -> nlequation_oop(u)
331339

332340
resid = f(integrator.du, u0, p, t)
333-
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t) && return
341+
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
334342

335343
jac = if isnothing(f.jac)
336344
f.jac
@@ -395,7 +403,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
395403

396404
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
397405

398-
check_dae_tolerance(integrator, tmp, alg.abstol, t) && return
406+
check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return
399407
alg_u = @view u[algebraic_vars]
400408

401409
# These non-dual values are thus used to make the caches
@@ -474,7 +482,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
474482
du = f(u0, p, t)
475483
resid = _vec(du)[algebraic_eqs]
476484

477-
check_dae_tolerance(integrator, resid, alg.abstol, t) && return
485+
check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return
478486

479487
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
480488
if isAD
@@ -553,7 +561,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
553561
normtmp = get_tmp_cache(integrator)[1]
554562
f(normtmp, du, u, p, t)
555563

556-
if check_dae_tolerance(integrator, normtmp, alg.abstol, t)
564+
if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace)
557565
return
558566
elseif differential_vars === nothing
559567
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -614,7 +622,8 @@ function _initialize_dae!(integrator, prob::DAEProblem,
614622
@unpack p, t, f = integrator
615623
differential_vars = prob.differential_vars
616624

617-
if check_dae_tolerance(integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t)
625+
if check_dae_tolerance(
626+
integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace)
618627
return
619628
elseif differential_vars === nothing
620629
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")

0 commit comments

Comments
 (0)