Skip to content

Commit f117f80

Browse files
Merge pull request #974 from SCiarella/master
Extend multiple_shoot loss to multidimensional NeuralODEs
2 parents 0eb457d + f32a32d commit f117f80

File tree

2 files changed

+177
-155
lines changed

2 files changed

+177
-155
lines changed

src/multiple_shooting.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ Arguments:
3636
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
3737
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
3838
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
39-
datasize = size(ode_data, 2)
39+
datasize = size(ode_data, ndims(ode_data))
40+
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 1)
4041

4142
if group_size < 2 || group_size > datasize
4243
throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
@@ -48,7 +49,7 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
4849
# Multiple shooting predictions
4950
sols = [solve(
5051
remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
51-
u0 = ode_data[:, first(rg)]),
52+
u0 = ode_data[griddims..., first(rg)]),
5253
solver;
5354
saveat = tsteps[rg],
5455
kwargs...) for rg in ranges]
@@ -61,15 +62,15 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
6162
# Calculate multiple shooting loss
6263
loss = 0
6364
for (i, rg) in enumerate(ranges)
64-
u = ode_data[:, rg]
65-
û = group_predictions[i]
65+
u = ode_data[griddims..., rg]
66+
û = group_predictions[i][griddims..., :]
6667
loss += loss_function(u, û)
6768

6869
if i > 1
6970
# Ensure continuity between last state in previous prediction
7071
# and current initial condition in ode_data
7172
loss += continuity_term *
72-
continuity_loss(group_predictions[i - 1][:, end], u[:, 1])
73+
continuity_loss(group_predictions[i - 1][griddims..., end], u[griddims..., 1])
7374
end
7475
end
7576

@@ -121,16 +122,18 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
121122
ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
122123
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
123124
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
124-
datasize = size(ode_data, 2)
125+
ntraj = size(ode_data, ndims(ode_data))
126+
datasize = size(ode_data, ndims(ode_data)-1)
127+
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 2)
125128
prob = ensembleprob.prob
126129

127130
if group_size < 2 || group_size > datasize
128131
throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
129132
end
130133

131-
@assert ndims(ode_data)==3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
132-
@assert size(ode_data, 2) == length(tsteps)
133-
@assert size(ode_data, 3) == kwargs[:trajectories]
134+
@assert ndims(ode_data)>=3 "ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
135+
@assert datasize == length(tsteps)
136+
@assert ntraj == kwargs[:trajectories]
134137

135138
# Get ranges that partition data to groups of size group_size
136139
ranges = group_ranges(datasize, group_size)
@@ -140,7 +143,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
140143
rg -> begin
141144
newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], tsteps[last(rg)]))
142145
function prob_func(prob, i, repeat)
143-
remake(prob; u0 = ode_data[:, first(rg), i])
146+
remake(prob; u0 = ode_data[griddims..., first(rg), i])
144147
end
145148
newensembleprob = EnsembleProblem(
146149
newprob, prob_func, ensembleprob.output_func, ensembleprob.reduction,
@@ -158,7 +161,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
158161
loss = 0
159162
for (i, rg) in enumerate(ranges)
160163
û = group_predictions[i]
161-
u = ode_data[:, rg, :] # trajectories are at dims 3
164+
u = ode_data[griddims..., rg, :] # trajectories are along the last dimension
162165
# just summing up losses for all trajectories
163166
# but other alternatives might be considered
164167

@@ -168,7 +171,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
168171
# Ensure continuity between last state in previous prediction
169172
# and current initial condition in ode_data
170173
loss += continuity_term *
171-
continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :])
174+
continuity_loss(group_predictions[i - 1][griddims..., end, :], u[griddims..., 1, :])
172175
end
173176
end
174177

test/multiple_shoot_tests.jl

Lines changed: 162 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -12,148 +12,167 @@
1212
@test_throws DomainError group_ranges(10, 1)
1313
@test_throws DomainError group_ranges(10, 11)
1414

15-
## Define initial conditions and time steps
16-
datasize = 30
17-
u0 = Float32[2.0, 0.0]
18-
tspan = (0.0f0, 5.0f0)
19-
tsteps = range(tspan[1], tspan[2]; length = datasize)
20-
21-
# Get the data
22-
function trueODEfunc(du, u, p, t)
23-
true_A = [-0.1 2.0; -2.0 -0.1]
24-
du .= ((u .^ 3)'true_A)'
15+
# Test configurations
16+
test_configs = [
17+
(
18+
name = "Vector Test Config",
19+
u0 = Float32[2.0, 0.0],
20+
ode_func = (du, u, p, t) -> (du .= ((u .^ 3)'*[-0.1 2.0; -2.0 -0.1])'),
21+
nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)),
22+
u0s_ensemble = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
23+
),
24+
(
25+
name = "Multi-D Test Config",
26+
u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0],
27+
ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
28+
nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)),
29+
u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
30+
)
31+
]
32+
33+
for config in test_configs
34+
@info "Running tests for: $(config.name)"
35+
36+
## Define initial conditions and time steps
37+
datasize = 30
38+
u0 = config.u0
39+
tspan = (0.0f0, 5.0f0)
40+
tsteps = range(tspan[1], tspan[2]; length = datasize)
41+
42+
# Get the data
43+
trueODEfunc = config.ode_func
44+
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
45+
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
46+
47+
# Define the Neural Network
48+
nn = config.nn
49+
p_init, st = Lux.setup(rng, nn)
50+
p_init = ComponentArray(p_init)
51+
52+
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
53+
prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init)
54+
55+
predict_single_shooting(p) = Array(first(neuralode(u0, p, st)))
56+
57+
# Define loss function
58+
loss_function(data, pred) = sum(abs2, data - pred)
59+
60+
## Evaluate Single Shooting
61+
function loss_single_shooting(p)
62+
pred = predict_single_shooting(p)
63+
l = loss_function(ode_data, pred)
64+
return l
65+
end
66+
67+
adtype = Optimization.AutoZygote()
68+
optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype)
69+
optprob = Optimization.OptimizationProblem(optf, p_init)
70+
res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
71+
72+
loss_ss = loss_single_shooting(res_single_shooting.minimizer)
73+
@info "Single shooting loss: $(loss_ss)"
74+
75+
## Test Multiple Shooting
76+
group_size = 3
77+
continuity_term = 200
78+
79+
function loss_multiple_shooting(p)
80+
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(),
81+
group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs
82+
end
83+
84+
adtype = Optimization.AutoZygote()
85+
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype)
86+
optprob = Optimization.OptimizationProblem(optf, p_init)
87+
res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
88+
89+
# Calculate single shooting loss with parameter from multiple_shoot training
90+
loss_ms = loss_single_shooting(res_ms.minimizer)
91+
println("Multiple shooting loss: $(loss_ms)")
92+
@test loss_ms < 10loss_ss
93+
94+
# Test with custom loss function
95+
group_size = 4
96+
continuity_term = 50
97+
98+
function continuity_loss_abs2(û_end, u_0)
99+
return sum(abs2, û_end - u_0) # using abs2 instead of default abs
100+
end
101+
102+
function loss_multiple_shooting_abs2(p)
103+
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function,
104+
continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1]
105+
end
106+
107+
adtype = Optimization.AutoZygote()
108+
optf = Optimization.OptimizationFunction(
109+
(p, _) -> loss_multiple_shooting_abs2(p), adtype)
110+
optprob = Optimization.OptimizationProblem(optf, p_init)
111+
res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
112+
113+
loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer)
114+
println("Multiple shooting loss with abs2: $(loss_ms_abs2)")
115+
@test loss_ms_abs2 < loss_ss
116+
117+
## Test different SensitivityAlgorithm (default is InterpolatingAdjoint)
118+
function loss_multiple_shooting_fd(p)
119+
return multiple_shoot(
120+
p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
121+
Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1]
122+
end
123+
124+
adtype = Optimization.AutoZygote()
125+
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype)
126+
optprob = Optimization.OptimizationProblem(optf, p_init)
127+
res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
128+
129+
# Calculate single shooting loss with parameter from multiple_shoot training
130+
loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer)
131+
println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)")
132+
@test loss_ms_fd < 10loss_ss
133+
134+
# Integration return codes `!= :Success` should return infinite loss.
135+
# In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`.
136+
loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
137+
Tsit5(), datasize; maxiters = 1, verbose = false)[1]
138+
@test loss_fail == Inf
139+
140+
## Test for DomainErrors
141+
@test_throws DomainError multiple_shoot(
142+
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1)
143+
@test_throws DomainError multiple_shoot(
144+
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1)
145+
146+
## Ensembles
147+
u0s = config.u0s_ensemble
148+
function prob_func(prob, i, repeat)
149+
remake(prob; u0 = u0s[i])
150+
end
151+
ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func)
152+
ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func)
153+
ensemble_alg = EnsembleThreads()
154+
trajectories = 2
155+
ode_data_ensemble = Array(solve(
156+
ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps))
157+
158+
group_size = 3
159+
continuity_term = 200
160+
function loss_multiple_shooting_ens(p)
161+
return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
162+
loss_function, Tsit5(), group_size; continuity_term,
163+
trajectories, abstol = 1e-8, reltol = 1e-6)[1]
164+
end
165+
166+
adtype = Optimization.AutoZygote()
167+
optf = Optimization.OptimizationFunction(
168+
(p, _) -> loss_multiple_shooting_ens(p), adtype)
169+
optprob = Optimization.OptimizationProblem(optf, p_init)
170+
res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
171+
172+
loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer)
173+
174+
println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)")
175+
176+
@test loss_ms_ensembles < 10loss_ss
25177
end
26-
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
27-
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
28-
29-
# Define the Neural Network
30-
nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2))
31-
p_init, st = Lux.setup(rng, nn)
32-
p_init = ComponentArray(p_init)
33-
34-
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
35-
prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init)
36-
37-
predict_single_shooting(p) = Array(first(neuralode(u0, p, st)))
38-
39-
# Define loss function
40-
loss_function(data, pred) = sum(abs2, data - pred)
41-
42-
## Evaluate Single Shooting
43-
function loss_single_shooting(p)
44-
pred = predict_single_shooting(p)
45-
l = loss_function(ode_data, pred)
46-
return l
47-
end
48-
49-
adtype = Optimization.AutoZygote()
50-
optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype)
51-
optprob = Optimization.OptimizationProblem(optf, p_init)
52-
res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
53-
54-
loss_ss = loss_single_shooting(res_single_shooting.minimizer)
55-
@info "Single shooting loss: $(loss_ss)"
56-
57-
## Test Multiple Shooting
58-
group_size = 3
59-
continuity_term = 200
60-
61-
function loss_multiple_shooting(p)
62-
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(),
63-
group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs
64-
end
65-
66-
adtype = Optimization.AutoZygote()
67-
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype)
68-
optprob = Optimization.OptimizationProblem(optf, p_init)
69-
res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
70-
71-
# Calculate single shooting loss with parameter from multiple_shoot training
72-
loss_ms = loss_single_shooting(res_ms.minimizer)
73-
println("Multiple shooting loss: $(loss_ms)")
74-
@test loss_ms < 10loss_ss
75-
76-
# Test with custom loss function
77-
group_size = 4
78-
continuity_term = 50
79-
80-
function continuity_loss_abs2(û_end, u_0)
81-
return sum(abs2, û_end - u_0) # using abs2 instead of default abs
82-
end
83-
84-
function loss_multiple_shooting_abs2(p)
85-
return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function,
86-
continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1]
87-
end
88-
89-
adtype = Optimization.AutoZygote()
90-
optf = Optimization.OptimizationFunction(
91-
(p, _) -> loss_multiple_shooting_abs2(p), adtype)
92-
optprob = Optimization.OptimizationProblem(optf, p_init)
93-
res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
94-
95-
loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer)
96-
println("Multiple shooting loss with abs2: $(loss_ms_abs2)")
97-
@test loss_ms_abs2 < loss_ss
98-
99-
## Test different SensitivityAlgorithm (default is InterpolatingAdjoint)
100-
function loss_multiple_shooting_fd(p)
101-
return multiple_shoot(
102-
p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
103-
Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1]
104-
end
105-
106-
adtype = Optimization.AutoZygote()
107-
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype)
108-
optprob = Optimization.OptimizationProblem(optf, p_init)
109-
res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
110-
111-
# Calculate single shooting loss with parameter from multiple_shoot training
112-
loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer)
113-
println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)")
114-
@test loss_ms_fd < 10loss_ss
115-
116-
# Integration return codes `!= :Success` should return infinite loss.
117-
# In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`.
118-
loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
119-
Tsit5(), datasize; maxiters = 1, verbose = false)[1]
120-
@test loss_fail == Inf
121-
122-
## Test for DomainErrors
123-
@test_throws DomainError multiple_shoot(
124-
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1)
125-
@test_throws DomainError multiple_shoot(
126-
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1)
127-
128-
## Ensembles
129-
u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
130-
function prob_func(prob, i, repeat)
131-
remake(prob; u0 = u0s[i])
132-
end
133-
ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func)
134-
ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func)
135-
ensemble_alg = EnsembleThreads()
136-
trajectories = 2
137-
ode_data_ensemble = Array(solve(
138-
ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps))
139-
140-
group_size = 3
141-
continuity_term = 200
142-
function loss_multiple_shooting_ens(p)
143-
return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
144-
loss_function, Tsit5(), group_size; continuity_term,
145-
trajectories, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs
146-
end
147-
148-
adtype = Optimization.AutoZygote()
149-
optf = Optimization.OptimizationFunction(
150-
(p, _) -> loss_multiple_shooting_ens(p), adtype)
151-
optprob = Optimization.OptimizationProblem(optf, p_init)
152-
res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300)
153-
154-
loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer)
155-
156-
println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)")
157-
158-
@test loss_ms_ensembles < 10loss_ss
159178
end

0 commit comments

Comments
 (0)