|
12 | 12 | @test_throws DomainError group_ranges(10, 1) |
13 | 13 | @test_throws DomainError group_ranges(10, 11) |
14 | 14 |
|
15 | | - ## Define initial conditions and time steps |
16 | | - datasize = 30 |
17 | | - u0 = Float32[2.0, 0.0] |
18 | | - tspan = (0.0f0, 5.0f0) |
19 | | - tsteps = range(tspan[1], tspan[2]; length = datasize) |
20 | | - |
21 | | - # Get the data |
22 | | - function trueODEfunc(du, u, p, t) |
23 | | - true_A = [-0.1 2.0; -2.0 -0.1] |
24 | | - du .= ((u .^ 3)'true_A)' |
| 15 | + # Test configurations |
| 16 | + test_configs = [ |
| 17 | + ( |
| 18 | + name = "Vector Test Config", |
| 19 | + u0 = Float32[2.0, 0.0], |
| 20 | + ode_func = (du, u, p, t) -> (du .= ((u .^ 3)'*[-0.1 2.0; -2.0 -0.1])'), |
| 21 | + nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)), |
| 22 | + u0s_ensemble = [Float32[2.0, 0.0], Float32[3.0, 1.0]] |
| 23 | + ), |
| 24 | + ( |
| 25 | + name = "Multi-D Test Config", |
| 26 | + u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], |
| 27 | + ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])), |
| 28 | + nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)), |
| 29 | + u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]] |
| 30 | + ) |
| 31 | + ] |
| 32 | + |
| 33 | + for config in test_configs |
| 34 | + @info "Running tests for: $(config.name)" |
| 35 | + |
| 36 | + ## Define initial conditions and time steps |
| 37 | + datasize = 30 |
| 38 | + u0 = config.u0 |
| 39 | + tspan = (0.0f0, 5.0f0) |
| 40 | + tsteps = range(tspan[1], tspan[2]; length = datasize) |
| 41 | + |
| 42 | + # Get the data |
| 43 | + trueODEfunc = config.ode_func |
| 44 | + prob_trueode = ODEProblem(trueODEfunc, u0, tspan) |
| 45 | + ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) |
| 46 | + |
| 47 | + # Define the Neural Network |
| 48 | + nn = config.nn |
| 49 | + p_init, st = Lux.setup(rng, nn) |
| 50 | + p_init = ComponentArray(p_init) |
| 51 | + |
| 52 | + neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) |
| 53 | + prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init) |
| 54 | + |
| 55 | + predict_single_shooting(p) = Array(first(neuralode(u0, p, st))) |
| 56 | + |
| 57 | + # Define loss function |
| 58 | + loss_function(data, pred) = sum(abs2, data - pred) |
| 59 | + |
| 60 | + ## Evaluate Single Shooting |
| 61 | + function loss_single_shooting(p) |
| 62 | + pred = predict_single_shooting(p) |
| 63 | + l = loss_function(ode_data, pred) |
| 64 | + return l |
| 65 | + end |
| 66 | + |
| 67 | + adtype = Optimization.AutoZygote() |
| 68 | + optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) |
| 69 | + optprob = Optimization.OptimizationProblem(optf, p_init) |
| 70 | + res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
| 71 | + |
| 72 | + loss_ss = loss_single_shooting(res_single_shooting.minimizer) |
| 73 | + @info "Single shooting loss: $(loss_ss)" |
| 74 | + |
| 75 | + ## Test Multiple Shooting |
| 76 | + group_size = 3 |
| 77 | + continuity_term = 200 |
| 78 | + |
| 79 | + function loss_multiple_shooting(p) |
| 80 | + return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), |
| 81 | + group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs |
| 82 | + end |
| 83 | + |
| 84 | + adtype = Optimization.AutoZygote() |
| 85 | + optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype) |
| 86 | + optprob = Optimization.OptimizationProblem(optf, p_init) |
| 87 | + res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
| 88 | + |
| 89 | + # Calculate single shooting loss with parameter from multiple_shoot training |
| 90 | + loss_ms = loss_single_shooting(res_ms.minimizer) |
| 91 | + println("Multiple shooting loss: $(loss_ms)") |
| 92 | + @test loss_ms < 10loss_ss |
| 93 | + |
| 94 | + # Test with custom loss function |
| 95 | + group_size = 4 |
| 96 | + continuity_term = 50 |
| 97 | + |
| 98 | + function continuity_loss_abs2(û_end, u_0) |
| 99 | + return sum(abs2, û_end - u_0) # using abs2 instead of default abs |
| 100 | + end |
| 101 | + |
| 102 | + function loss_multiple_shooting_abs2(p) |
| 103 | + return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, |
| 104 | + continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1] |
| 105 | + end |
| 106 | + |
| 107 | + adtype = Optimization.AutoZygote() |
| 108 | + optf = Optimization.OptimizationFunction( |
| 109 | + (p, _) -> loss_multiple_shooting_abs2(p), adtype) |
| 110 | + optprob = Optimization.OptimizationProblem(optf, p_init) |
| 111 | + res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
| 112 | + |
| 113 | + loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer) |
| 114 | + println("Multiple shooting loss with abs2: $(loss_ms_abs2)") |
| 115 | + @test loss_ms_abs2 < loss_ss |
| 116 | + |
| 117 | + ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) |
| 118 | + function loss_multiple_shooting_fd(p) |
| 119 | + return multiple_shoot( |
| 120 | + p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2, |
| 121 | + Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1] |
| 122 | + end |
| 123 | + |
| 124 | + adtype = Optimization.AutoZygote() |
| 125 | + optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype) |
| 126 | + optprob = Optimization.OptimizationProblem(optf, p_init) |
| 127 | + res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
| 128 | + |
| 129 | + # Calculate single shooting loss with parameter from multiple_shoot training |
| 130 | + loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer) |
| 131 | + println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") |
| 132 | + @test loss_ms_fd < 10loss_ss |
| 133 | + |
| 134 | + # Integration return codes `!= :Success` should return infinite loss. |
| 135 | + # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. |
| 136 | + loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, |
| 137 | + Tsit5(), datasize; maxiters = 1, verbose = false)[1] |
| 138 | + @test loss_fail == Inf |
| 139 | + |
| 140 | + ## Test for DomainErrors |
| 141 | + @test_throws DomainError multiple_shoot( |
| 142 | + p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1) |
| 143 | + @test_throws DomainError multiple_shoot( |
| 144 | + p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1) |
| 145 | + |
| 146 | + ## Ensembles |
| 147 | + u0s = config.u0s_ensemble |
| 148 | + function prob_func(prob, i, repeat) |
| 149 | + remake(prob; u0 = u0s[i]) |
| 150 | + end |
| 151 | + ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func) |
| 152 | + ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func) |
| 153 | + ensemble_alg = EnsembleThreads() |
| 154 | + trajectories = 2 |
| 155 | + ode_data_ensemble = Array(solve( |
| 156 | + ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps)) |
| 157 | + |
| 158 | + group_size = 3 |
| 159 | + continuity_term = 200 |
| 160 | + function loss_multiple_shooting_ens(p) |
| 161 | + return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, |
| 162 | + loss_function, Tsit5(), group_size; continuity_term, |
| 163 | + trajectories, abstol = 1e-8, reltol = 1e-6)[1] |
| 164 | + end |
| 165 | + |
| 166 | + adtype = Optimization.AutoZygote() |
| 167 | + optf = Optimization.OptimizationFunction( |
| 168 | + (p, _) -> loss_multiple_shooting_ens(p), adtype) |
| 169 | + optprob = Optimization.OptimizationProblem(optf, p_init) |
| 170 | + res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
| 171 | + |
| 172 | + loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer) |
| 173 | + |
| 174 | + println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)") |
| 175 | + |
| 176 | + @test loss_ms_ensembles < 10loss_ss |
25 | 177 | end |
26 | | - prob_trueode = ODEProblem(trueODEfunc, u0, tspan) |
27 | | - ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) |
28 | | - |
29 | | - # Define the Neural Network |
30 | | - nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)) |
31 | | - p_init, st = Lux.setup(rng, nn) |
32 | | - p_init = ComponentArray(p_init) |
33 | | - |
34 | | - neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) |
35 | | - prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init) |
36 | | - |
37 | | - predict_single_shooting(p) = Array(first(neuralode(u0, p, st))) |
38 | | - |
39 | | - # Define loss function |
40 | | - loss_function(data, pred) = sum(abs2, data - pred) |
41 | | - |
42 | | - ## Evaluate Single Shooting |
43 | | - function loss_single_shooting(p) |
44 | | - pred = predict_single_shooting(p) |
45 | | - l = loss_function(ode_data, pred) |
46 | | - return l |
47 | | - end |
48 | | - |
49 | | - adtype = Optimization.AutoZygote() |
50 | | - optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) |
51 | | - optprob = Optimization.OptimizationProblem(optf, p_init) |
52 | | - res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
53 | | - |
54 | | - loss_ss = loss_single_shooting(res_single_shooting.minimizer) |
55 | | - @info "Single shooting loss: $(loss_ss)" |
56 | | - |
57 | | - ## Test Multiple Shooting |
58 | | - group_size = 3 |
59 | | - continuity_term = 200 |
60 | | - |
61 | | - function loss_multiple_shooting(p) |
62 | | - return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), |
63 | | - group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs |
64 | | - end |
65 | | - |
66 | | - adtype = Optimization.AutoZygote() |
67 | | - optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype) |
68 | | - optprob = Optimization.OptimizationProblem(optf, p_init) |
69 | | - res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
70 | | - |
71 | | - # Calculate single shooting loss with parameter from multiple_shoot training |
72 | | - loss_ms = loss_single_shooting(res_ms.minimizer) |
73 | | - println("Multiple shooting loss: $(loss_ms)") |
74 | | - @test loss_ms < 10loss_ss |
75 | | - |
76 | | - # Test with custom loss function |
77 | | - group_size = 4 |
78 | | - continuity_term = 50 |
79 | | - |
80 | | - function continuity_loss_abs2(û_end, u_0) |
81 | | - return sum(abs2, û_end - u_0) # using abs2 instead of default abs |
82 | | - end |
83 | | - |
84 | | - function loss_multiple_shooting_abs2(p) |
85 | | - return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, |
86 | | - continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1] |
87 | | - end |
88 | | - |
89 | | - adtype = Optimization.AutoZygote() |
90 | | - optf = Optimization.OptimizationFunction( |
91 | | - (p, _) -> loss_multiple_shooting_abs2(p), adtype) |
92 | | - optprob = Optimization.OptimizationProblem(optf, p_init) |
93 | | - res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
94 | | - |
95 | | - loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer) |
96 | | - println("Multiple shooting loss with abs2: $(loss_ms_abs2)") |
97 | | - @test loss_ms_abs2 < loss_ss |
98 | | - |
99 | | - ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) |
100 | | - function loss_multiple_shooting_fd(p) |
101 | | - return multiple_shoot( |
102 | | - p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2, |
103 | | - Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1] |
104 | | - end |
105 | | - |
106 | | - adtype = Optimization.AutoZygote() |
107 | | - optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype) |
108 | | - optprob = Optimization.OptimizationProblem(optf, p_init) |
109 | | - res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
110 | | - |
111 | | - # Calculate single shooting loss with parameter from multiple_shoot training |
112 | | - loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer) |
113 | | - println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") |
114 | | - @test loss_ms_fd < 10loss_ss |
115 | | - |
116 | | - # Integration return codes `!= :Success` should return infinite loss. |
117 | | - # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. |
118 | | - loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, |
119 | | - Tsit5(), datasize; maxiters = 1, verbose = false)[1] |
120 | | - @test loss_fail == Inf |
121 | | - |
122 | | - ## Test for DomainErrors |
123 | | - @test_throws DomainError multiple_shoot( |
124 | | - p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1) |
125 | | - @test_throws DomainError multiple_shoot( |
126 | | - p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1) |
127 | | - |
128 | | - ## Ensembles |
129 | | - u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]] |
130 | | - function prob_func(prob, i, repeat) |
131 | | - remake(prob; u0 = u0s[i]) |
132 | | - end |
133 | | - ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func) |
134 | | - ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func) |
135 | | - ensemble_alg = EnsembleThreads() |
136 | | - trajectories = 2 |
137 | | - ode_data_ensemble = Array(solve( |
138 | | - ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps)) |
139 | | - |
140 | | - group_size = 3 |
141 | | - continuity_term = 200 |
142 | | - function loss_multiple_shooting_ens(p) |
143 | | - return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, |
144 | | - loss_function, Tsit5(), group_size; continuity_term, |
145 | | - trajectories, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs |
146 | | - end |
147 | | - |
148 | | - adtype = Optimization.AutoZygote() |
149 | | - optf = Optimization.OptimizationFunction( |
150 | | - (p, _) -> loss_multiple_shooting_ens(p), adtype) |
151 | | - optprob = Optimization.OptimizationProblem(optf, p_init) |
152 | | - res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) |
153 | | - |
154 | | - loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer) |
155 | | - |
156 | | - println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)") |
157 | | - |
158 | | - @test loss_ms_ensembles < 10loss_ss |
159 | 178 | end |
0 commit comments