Skip to content

Commit bf87624

Browse files
committed
Use FunctionWrapper to improve type stability
1 parent e451f40 commit bf87624

File tree

7 files changed

+79
-65
lines changed

7 files changed

+79
-65
lines changed

lib/OrdinaryDiffEqTaylorSeries/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.1.0"
66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
9+
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
1112
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
@@ -23,6 +24,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2324
DiffEqBase = "6.152.2"
2425
DiffEqDevTools = "2.44.4"
2526
FastBroadcast = "0.3.5"
27+
FunctionWrappers = "1.1.3"
2628
LinearAlgebra = "<0.0.1, 1"
2729
MuladdMacro = "0.2.4"
2830
OrdinaryDiffEqCore = "1.1"

lib/OrdinaryDiffEqTaylorSeries/src/OrdinaryDiffEqTaylorSeries.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ using TaylorDiff, Symbolics
2525
using TaylorDiff: make_seed, get_coefficient, append_coefficient, flatten
2626
import DiffEqBase: @def
2727
import OrdinaryDiffEqCore
28+
using FunctionWrappers
29+
import FunctionWrappers: FunctionWrapper
2830

2931
using Reexport
3032
@reexport using DiffEqBase

lib/OrdinaryDiffEqTaylorSeries/src/TaylorSeries_caches.jl

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ end
3939
get_fsalfirstlast(cache::ExplicitTaylor2Cache, u) = (cache.k1, cache.k1)
4040

4141
@cache struct ExplicitTaylorCache{
42-
P, uType, taylorType, uNoUnitsType, StageLimiter, StepLimiter,
42+
P, tType, uType, taylorType, uNoUnitsType, StageLimiter, StepLimiter,
4343
Thread} <: OrdinaryDiffEqMutableCache
4444
order::Val{P}
45-
jet::Function
45+
jet::FunctionWrapper{Nothing, Tuple{taylorType, uType, tType}}
4646
u::uType
4747
uprev::uType
4848
utaylor::taylorType
@@ -54,45 +54,49 @@ get_fsalfirstlast(cache::ExplicitTaylor2Cache, u) = (cache.k1, cache.k1)
5454
thread::Thread
5555
end
5656

57-
function alg_cache(alg::ExplicitTaylor, u, rate_prototype, ::Type{uEltypeNoUnits},
57+
function alg_cache(alg::ExplicitTaylor{P}, u, rate_prototype, ::Type{uEltypeNoUnits},
5858
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
5959
dt, reltol, p, calck,
60-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
61-
_, jet_iip = build_jet(f, p, alg.order, length(u))
60+
::Val{true}) where {P, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
61+
_, jet_iip = build_jet(f, p, P, length(u))
6262
utaylor = TaylorDiff.make_seed(u, zero(u), alg.order)
63+
jet_wrapped = FunctionWrapper{Nothing, Tuple{typeof(utaylor), typeof(u), typeof(t)}}(jet_iip)
6364
utilde = zero(u)
6465
atmp = similar(u, uEltypeNoUnits)
6566
recursivefill!(atmp, false)
6667
tmp = zero(u)
67-
ExplicitTaylorCache(alg.order, jet_iip, u, uprev, utaylor, utilde, tmp, atmp,
68+
ExplicitTaylorCache(alg.order, jet_wrapped, u, uprev, utaylor, utilde, tmp, atmp,
6869
alg.stage_limiter!, alg.step_limiter!, alg.thread)
6970
end
7071

7172
get_fsalfirstlast(cache::ExplicitTaylorCache, u) = (cache.u, cache.u)
7273

73-
struct ExplicitTaylorConstantCache{P} <: OrdinaryDiffEqConstantCache
74+
struct ExplicitTaylorConstantCache{P, taylorType, uType, tType} <:
75+
OrdinaryDiffEqConstantCache
7476
order::Val{P}
75-
jet::Function
77+
jet::FunctionWrapper{taylorType, Tuple{uType, tType}}
7678
end
77-
function alg_cache(::ExplicitTaylor{P}, u, rate_prototype, ::Type{uEltypeNoUnits},
79+
function alg_cache(alg::ExplicitTaylor{P}, u, rate_prototype, ::Type{uEltypeNoUnits},
7880
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
7981
dt, reltol, p, calck,
8082
::Val{false}) where {P, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
8183
if u isa AbstractArray
82-
jet, _ = build_jet(f, p, Val(P), length(u))
84+
jet, _ = build_jet(f, p, P, length(u))
8385
else
84-
jet = build_jet(f, p, Val(P))
86+
jet = build_jet(f, p, P)
8587
end
86-
ExplicitTaylorConstantCache(Val(P), jet)
88+
utaylor = TaylorDiff.make_seed(u, zero(u), alg.order) # not used, but needed for type
89+
jet_wrapped = FunctionWrapper{typeof(utaylor), Tuple{typeof(u), typeof(t)}}(jet)
90+
ExplicitTaylorConstantCache(alg.order, jet_wrapped)
8791
end
8892

89-
@cache struct ExplicitTaylorAdaptiveOrderCache{
90-
uType, taylorType, uNoUnitsType, StageLimiter, StepLimiter,
93+
@cache struct ExplicitTaylorAdaptiveOrderCache{P, Q,
94+
tType, uType, taylorType, uNoUnitsType, StageLimiter, StepLimiter,
9195
Thread} <: OrdinaryDiffEqMutableCache
92-
min_order::Int
93-
max_order::Int
94-
current_order::Ref{Int}
95-
jets::Vector{Function}
96+
min_order::Val{P}
97+
max_order::Val{Q}
98+
current_order::Base.RefValue{Int}
99+
jets::Vector{FunctionWrapper{Nothing, Tuple{taylorType, uType, tType}}}
96100
u::uType
97101
uprev::uType
98102
utaylor::taylorType
@@ -108,45 +112,51 @@ function alg_cache(
108112
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
109113
dt, reltol, p, calck,
110114
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
111-
jets = Function[]
112-
for order in (alg.min_order):(alg.max_order)
113-
_, jet_iip = build_jet(f, p, Val(order), length(u))
115+
utaylor = TaylorDiff.make_seed(u, zero(u), alg.max_order)
116+
jets = FunctionWrapper{Nothing, Tuple{typeof(utaylor), typeof(u), typeof(t)}}[]
117+
min_order_value = get_value(alg.min_order)
118+
max_order_value = get_value(alg.max_order)
119+
for order in min_order_value:max_order_value
120+
jet_iip = build_jet(f, p, order, length(u))[2]
114121
push!(jets, jet_iip)
115122
end
116-
utaylor = TaylorDiff.make_seed(u, zero(u), Val(alg.max_order))
117123
utilde = zero(u)
118124
atmp = similar(u, uEltypeNoUnits)
119125
recursivefill!(atmp, false)
120126
tmp = zero(u)
121-
current_order = Ref{Int}(alg.min_order)
127+
current_order = Ref(min_order_value)
122128
ExplicitTaylorAdaptiveOrderCache(alg.min_order, alg.max_order, current_order,
123129
jets, u, uprev, utaylor, utilde, tmp, atmp,
124130
alg.stage_limiter!, alg.step_limiter!, alg.thread)
125131
end
126132

127133
get_fsalfirstlast(cache::ExplicitTaylorAdaptiveOrderCache, u) = (cache.u, cache.u)
128134

129-
struct ExplicitTaylorAdaptiveOrderConstantCache <: OrdinaryDiffEqConstantCache
130-
min_order::Int
131-
max_order::Int
132-
current_order::Ref{Int}
133-
jets::Vector{Function}
135+
struct ExplicitTaylorAdaptiveOrderConstantCache{P, Q, taylorType, uType, tType} <:
136+
OrdinaryDiffEqConstantCache
137+
min_order::Val{P}
138+
max_order::Val{Q}
139+
current_order::Base.RefValue{Int}
140+
jets::Vector{FunctionWrapper{taylorType, Tuple{uType, tType}}}
134141
end
135142
function alg_cache(
136143
alg::ExplicitTaylorAdaptiveOrder, u, rate_prototype, ::Type{uEltypeNoUnits},
137144
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
138145
dt, reltol, p, calck,
139146
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
140-
jets = Function[]
141-
for order in (alg.min_order):(alg.max_order)
147+
utaylor = TaylorDiff.make_seed(u, zero(u), alg.max_order) # not used, but needed for type
148+
jets = FunctionWrapper{typeof(utaylor), Tuple{typeof(u), typeof(t)}}[]
149+
min_order_value = get_value(alg.min_order)
150+
max_order_value = get_value(alg.max_order)
151+
for order in min_order_value:max_order_value
142152
if u isa AbstractArray
143-
jet, _ = build_jet(f, p, Val(order), length(u))
153+
jet, _ = build_jet(f, p, order, length(u))
144154
else
145-
jet = build_jet(f, p, Val(order))
155+
jet = build_jet(f, p, order)
146156
end
147157
push!(jets, jet)
148158
end
149-
current_order = Ref{Int}(alg.min_order)
159+
current_order = Ref(min_order_value)
150160
ExplicitTaylorAdaptiveOrderConstantCache(
151161
alg.min_order, alg.max_order, current_order, jets)
152162
end

lib/OrdinaryDiffEqTaylorSeries/src/TaylorSeries_perform_step.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,6 @@ end
100100
end
101101

102102
function initialize!(integrator, cache::ExplicitTaylorAdaptiveOrderCache)
103-
integrator.kshortsize = cache.max_order
104-
resize!(integrator.k, cache.max_order)
105-
# Setup k pointers
106-
for i in 1:(cache.max_order)
107-
integrator.k[i] = get_coefficient(cache.utaylor, i)
108-
end
109-
return nothing
110103
end
111104

112105
@muladd function perform_step!(
@@ -115,7 +108,9 @@ end
115108
alg = unwrap_alg(integrator, false)
116109
@unpack jets, current_order, min_order, max_order, utaylor, utilde, tmp, atmp, thread = cache
117110

118-
jet_index = current_order[] - min_order + 1
111+
min_order_value = get_value(min_order)
112+
max_order_value = get_value(max_order)
113+
jet_index = current_order[] - min_order_value + 1
119114
# compute one additional order for adaptive order
120115
jet = jets[jet_index + 1]
121116
jet(utaylor, uprev, t)
@@ -125,8 +120,8 @@ end
125120
OrdinaryDiffEqCore.increment_nf!(integrator.stats, current_order[] + 1)
126121
if integrator.opts.adaptive
127122
min_work = Inf
128-
start_order = max(min_order, current_order[] - 1)
129-
end_order = min(max_order, current_order[] + 1)
123+
start_order = max(min_order_value, current_order[] - 1)
124+
end_order = min(max_order_value - 1, current_order[] + 1)
130125
for i in start_order:end_order
131126
A = i * i
132127
@.. broadcast=false thread=thread utilde=TaylorDiff.get_coefficient(
@@ -158,8 +153,9 @@ end
158153
end
159154

160155
function initialize!(integrator, cache::ExplicitTaylorAdaptiveOrderConstantCache)
161-
integrator.kshortsize = cache.max_order
162-
integrator.k = typeof(integrator.k)(undef, cache.max_order)
156+
max_order_value = get_value(cache.max_order)
157+
integrator.kshortsize = max_order_value
158+
integrator.k = typeof(integrator.k)(undef, max_order_value)
163159
return nothing
164160
end
165161

@@ -169,16 +165,18 @@ end
169165
alg = unwrap_alg(integrator, false)
170166
@unpack jets, current_order, min_order, max_order = cache
171167

172-
jet_index = current_order[] - min_order + 1
168+
min_order_value = get_value(min_order)
169+
max_order_value = get_value(max_order)
170+
jet_index = current_order[] - min_order_value + 1
173171
# compute one additional order for adaptive order
174172
jet = jets[jet_index + 1]
175173
utaylor = jet(uprev, t)
176174
u = map(x -> evaluate_polynomial(x, dt), utaylor)
177175
OrdinaryDiffEqCore.increment_nf!(integrator.stats, current_order[] + 1)
178176
if integrator.opts.adaptive
179177
min_work = Inf
180-
start_order = max(min_order, current_order[] - 1)
181-
end_order = min(max_order, current_order[] + 1)
178+
start_order = max(min_order_value, current_order[] - 1)
179+
end_order = min(max_order_value, current_order[] + 1)
182180
for i in start_order:end_order
183181
A = i * i
184182
utilde = TaylorDiff.get_coefficient(utaylor, i) * dt^i

lib/OrdinaryDiffEqTaylorSeries/src/alg_utils.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ alg_stability_size(alg::ExplicitTaylor2) = 1
44
alg_order(::ExplicitTaylor{P}) where {P} = P
55
alg_stability_size(alg::ExplicitTaylor) = 1
66

7-
alg_order(alg::ExplicitTaylorAdaptiveOrder) = alg.min_order
7+
alg_order(alg::ExplicitTaylorAdaptiveOrder) = get_value(alg.min_order)
88
get_current_adaptive_order(::ExplicitTaylorAdaptiveOrder, cache) = cache.current_order[]
99
get_current_alg_order(::ExplicitTaylorAdaptiveOrder, cache) = cache.current_order[]
1010

@@ -16,11 +16,14 @@ function make_term(a)
1616
term(TaylorScalar, Symbolics.unwrap(a.value), map(Symbolics.unwrap, a.partials))
1717
end
1818

19-
function build_jet(f::ODEFunction{iip}, p, order::Val{P}, length = nothing) where {P, iip}
20-
if haskey(JET_CACHE, f)
21-
list = JET_CACHE[f]
22-
index = findfirst(x -> x[1] == order && x[2] == p, list)
23-
index !== nothing && return list[index][3]
19+
function get_value(::Val{P}) where {P}
20+
return P
21+
end
22+
23+
function build_jet(f::ODEFunction{iip}, p, order, length = nothing) where {iip}
24+
key = (f, order, p)
25+
if haskey(JET_CACHE, key)
26+
return JET_CACHE[key]
2427
end
2528
@variables t0::Real
2629
u0 = isnothing(length) ? Symbolics.variable(:u0) : Symbolics.variables(:u0, 1:length)
@@ -32,7 +35,7 @@ function build_jet(f::ODEFunction{iip}, p, order::Val{P}, length = nothing) wher
3235
f0 = f(u0, p, t0)
3336
end
3437
u = TaylorDiff.make_seed(u0, f0, Val(1))
35-
for index in 2:P
38+
for index in 2:order
3639
t = TaylorScalar{index - 1}(t0, one(t0))
3740
if iip
3841
fu = similar(u)
@@ -45,10 +48,9 @@ function build_jet(f::ODEFunction{iip}, p, order::Val{P}, length = nothing) wher
4548
end
4649
u_term = make_term.(u)
4750
jet = build_function(u_term, u0, t0; expression = Val(false), cse = true)
48-
if !haskey(JET_CACHE, f)
49-
JET_CACHE[f] = []
51+
if !haskey(JET_CACHE, key)
52+
JET_CACHE[key] = jet
5053
end
51-
push!(JET_CACHE[f], (order, p, jet))
5254
return jet
5355
end
5456

lib/OrdinaryDiffEqTaylorSeries/src/algorithms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ end
2727
@doc explicit_rk_docstring(
2828
"An adaptive order explicit Taylor series method.",
2929
"ExplicitTaylorAdaptiveOrder")
30-
Base.@kwdef struct ExplicitTaylorAdaptiveOrder{StageLimiter, StepLimiter, Thread} <:
30+
Base.@kwdef struct ExplicitTaylorAdaptiveOrder{P, Q, StageLimiter, StepLimiter, Thread} <:
3131
OrdinaryDiffEqAdaptiveAlgorithm
32-
min_order::Int = 1
33-
max_order::Int = 10
32+
min_order::Val{P} = Val{1}()
33+
max_order::Val{Q} = Val{10}()
3434
stage_limiter!::StageLimiter = trivial_limiter!
3535
step_limiter!::StepLimiter = trivial_limiter!
3636
thread::Thread = False()

lib/OrdinaryDiffEqTaylorSeries/test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
@testset "Taylor2 Convergence Tests" begin
55
# Test convergence
6-
dts = 2. .^ (-8:-4)
6+
dts = 2.0 .^ (-8:-4)
77
testTol = 0.2
88
sim = test_convergence(dts, prob_ode_linear, ExplicitTaylor2())
99
@test sim.𝒪est[:final]2 atol=testTol
@@ -13,10 +13,10 @@ end
1313

1414
@testset "Taylor Convergence Tests" begin
1515
# Test convergence
16-
dts = 2. .^ (-8:-4)
16+
dts = 2.0 .^ (-8:-4)
1717
testTol = 0.2
1818
for N in 3:4
19-
alg = ExplicitTaylor(order=Val(N))
19+
alg = ExplicitTaylor(order = Val(N))
2020
sim = test_convergence(dts, prob_ode_linear, alg)
2121
@test sim.𝒪est[:final]N atol=testTol
2222
sim = test_convergence(dts, prob_ode_2Dlinear, alg)
@@ -25,7 +25,7 @@ end
2525
end
2626

2727
@testset "Taylor Adaptive time-step Tests" begin
28-
sol = solve(prob_ode_linear, ExplicitTaylor(order=Val(4)))
28+
sol = solve(prob_ode_linear, ExplicitTaylor(order = Val(4)))
2929
@test length(sol) < 20
3030
end
3131

0 commit comments

Comments
 (0)