Skip to content

Commit 7f7adb2

Browse files
Merge pull request #2873 from ChrisRackauckas-Claude/diffeqbase-initialization-algorithms
Update to use DiffEqBase initialization algorithms
2 parents 9088786 + d0db608 commit 7f7adb2

File tree

6 files changed

+33
-49
lines changed

6 files changed

+33
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ Adapt = "4.3"
112112
ArrayInterface = "7.19"
113113
CommonSolve = "0.2.4"
114114
DataStructures = "0.18.22, 0.19"
115-
DiffEqBase = "6.186"
115+
DiffEqBase = "6.190.2"
116116
DocStringExtensions = "0.9.5"
117117
EnumX = "1.0.5"
118118
ExplicitImports = "1.13.1"

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ FastPower = "1.1"
8585
Logging = "1.10"
8686
Mooncake = "0.4"
8787
AllocCheck = "0.2"
88-
DiffEqBase = "6.187"
88+
DiffEqBase = "6.190.2"
8989
FillArrays = "1.13"
9090
Adapt = "4.3"
9191
Reexport = "1.2"

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ import FastPower: fastpower
2727
import SciMLBase: solve!, step!, isadaptive
2828
import DiffEqBase: initialize!
2929

30+
# DAE Initialization algorithms
31+
import DiffEqBase: DefaultInit, ShampineCollocationInit, BrownFullBasicInit
32+
3033
# Internal utils
3134
import DiffEqBase: ODE_DEFAULT_NORM,
3235
ODE_DEFAULT_ISOUTOFDOMAIN, ODE_DEFAULT_PROG_MESSAGE,

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,3 @@
1-
struct DefaultInit <: SciMLBase.DAEInitializationAlgorithm end
2-
3-
struct ShampineCollocationInit{T, F} <: SciMLBase.DAEInitializationAlgorithm
4-
initdt::T
5-
nlsolve::F
6-
end
7-
function ShampineCollocationInit(; initdt = nothing, nlsolve = nothing)
8-
ShampineCollocationInit(initdt, nlsolve)
9-
end
10-
function ShampineCollocationInit(initdt)
11-
ShampineCollocationInit(; initdt = initdt, nlsolve = nothing)
12-
end
13-
14-
struct BrownFullBasicInit{T, F} <: SciMLBase.DAEInitializationAlgorithm
15-
abstol::T
16-
nlsolve::F
17-
end
18-
function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
19-
BrownFullBasicInit(abstol, nlsolve)
20-
end
21-
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
22-
231
## Notes
242

253
#=
@@ -44,7 +22,7 @@ end
4422

4523
## Default algorithms
4624

47-
function _initialize_dae!(integrator, prob::ODEProblem,
25+
function _initialize_dae!(integrator::ODEIntegrator, prob::ODEProblem,
4826
alg::DefaultInit, x::Union{Val{true}, Val{false}})
4927
if SciMLBase.has_initializeprob(prob.f)
5028
_initialize_dae!(integrator, prob,
@@ -58,7 +36,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
5836
end
5937
end
6038

61-
function _initialize_dae!(integrator, prob::DAEProblem,
39+
function _initialize_dae!(integrator::ODEIntegrator, prob::DAEProblem,
6240
alg::DefaultInit, x::Union{Val{true}, Val{false}})
6341
if SciMLBase.has_initializeprob(prob.f)
6442
_initialize_dae!(integrator, prob,
@@ -77,7 +55,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
7755
end
7856
end
7957

80-
function _initialize_dae!(integrator, prob::DiscreteProblem,
58+
function _initialize_dae!(integrator::ODEIntegrator, prob::DiscreteProblem,
8159
alg::DefaultInit, x::Union{Val{true}, Val{false}})
8260
if SciMLBase.has_initializeprob(prob.f)
8361
# integrator.opts.abstol is `false` for `DiscreteProblem`.
@@ -124,13 +102,13 @@ end
124102

125103
## NoInit
126104

127-
function _initialize_dae!(integrator, prob::AbstractDEProblem,
105+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem,
128106
alg::NoInit, x::Union{Val{true}, Val{false}})
129107
end
130108

131109
## OverrideInit
132110

133-
function _initialize_dae!(integrator, prob::AbstractDEProblem,
111+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem,
134112
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
135113
initializeprob = prob.f.initialization_data.initializeprob
136114

@@ -172,7 +150,7 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem,
172150
end
173151

174152
## CheckInit
175-
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
153+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem, alg::CheckInit,
176154
isinplace::Union{Val{true}, Val{false}})
177155
SciMLBase.get_initial_values(
178156
prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)

lib/OrdinaryDiffEqNonlinearSolve/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ RecursiveArrayTools = "3.36"
6363
ODEProblemLibrary = "0.1.8"
6464
PreallocationTools = "0.4"
6565
AllocCheck = "0.2"
66-
DiffEqBase = "6.176"
66+
DiffEqBase = "6.190.2"
6767
SafeTestsets = "0.1.0"
6868
SciMLOperators = "1.4"
6969
SciMLStructures = "1.7"

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,20 @@ Solve for `u`
5353
5454
=#
5555

56-
function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocationInit,
56+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
5757
isinplace::Val{true})
5858
@unpack p, t, f = integrator
5959
M = integrator.f.mass_matrix
6060
dtmax = integrator.opts.dtmax
6161
tmp = first(get_tmp_cache(integrator))
6262
u0 = integrator.u
6363

64-
dt = if alg.initdt === nothing
64+
initdt = alg.initdt
65+
dt = if initdt === nothing
6566
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
6667
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
6768
else
68-
alg.initdt
69+
initdt
6970
end
7071

7172
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
@@ -168,18 +169,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
168169
return
169170
end
170171

171-
function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocationInit,
172+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
172173
isinplace::Val{false})
173174
@unpack p, t, f = integrator
174175
u0 = integrator.u
175176
M = integrator.f.mass_matrix
176177
dtmax = integrator.opts.dtmax
177178

178-
dt = if alg.initdt === nothing
179+
initdt = alg.initdt
180+
dt = if initdt === nothing
179181
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
180182
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
181183
else
182-
alg.initdt
184+
initdt
183185
end
184186

185187
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
@@ -246,7 +248,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
246248
return
247249
end
248250

249-
function _initialize_dae!(integrator, prob::DAEProblem,
251+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
250252
alg::ShampineCollocationInit, isinplace::Val{true})
251253
@unpack p, t, f = integrator
252254
u0 = integrator.u
@@ -323,7 +325,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
323325
return
324326
end
325327

326-
function _initialize_dae!(integrator, prob::DAEProblem,
328+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
327329
alg::ShampineCollocationInit, isinplace::Val{false})
328330
@unpack p, t, f = integrator
329331
u0 = integrator.u
@@ -387,8 +389,8 @@ function algebraic_jacobian(jac_prototype::T, algebraic_eqs,
387389
jac_prototype[algebraic_eqs, algebraic_vars]
388390
end
389391

390-
function _initialize_dae!(integrator, prob::ODEProblem,
391-
alg::BrownFullBasicInit, isinplace::Val{true})
392+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem,
393+
alg::DiffEqBase.BrownFullBasicInit, isinplace::Val{true})
392394
@unpack p, t, f = integrator
393395
u = integrator.u
394396
M = integrator.f.mass_matrix
@@ -468,8 +470,8 @@ function _initialize_dae!(integrator, prob::ODEProblem,
468470
return
469471
end
470472

471-
function _initialize_dae!(integrator, prob::ODEProblem,
472-
alg::BrownFullBasicInit, isinplace::Val{false})
473+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem,
474+
alg::DiffEqBase.BrownFullBasicInit, isinplace::Val{false})
473475
@unpack p, t, f = integrator
474476

475477
u0 = integrator.u
@@ -536,8 +538,8 @@ function _initialize_dae!(integrator, prob::ODEProblem,
536538
return
537539
end
538540

539-
function _initialize_dae!(integrator, prob::DAEProblem,
540-
alg::BrownFullBasicInit, isinplace::Val{true})
541+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
542+
alg::DiffEqBase.BrownFullBasicInit, isinplace::Val{true})
541543
@unpack p, t, f = integrator
542544
differential_vars = prob.differential_vars
543545
u = integrator.u
@@ -592,8 +594,9 @@ function _initialize_dae!(integrator, prob::DAEProblem,
592594
f(out, du_tmp, uu, p, t)
593595
end
594596

595-
if alg.nlsolve !== nothing
596-
nlsolve = alg.nlsolve
597+
nlsolve_alg = alg.nlsolve
598+
if nlsolve_alg !== nothing
599+
nlsolve = nlsolve_alg
597600
else
598601
nlsolve = NewtonRaphson(autodiff = alg_autodiff(integrator.alg))
599602
end
@@ -617,8 +620,8 @@ function _initialize_dae!(integrator, prob::DAEProblem,
617620
return
618621
end
619622

620-
function _initialize_dae!(integrator, prob::DAEProblem,
621-
alg::BrownFullBasicInit, isinplace::Val{false})
623+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
624+
alg::DiffEqBase.BrownFullBasicInit, isinplace::Val{false})
622625
@unpack p, t, f = integrator
623626
differential_vars = prob.differential_vars
624627

0 commit comments

Comments
 (0)