Skip to content

Commit cc55450

Browse files
Merge pull request #987 from SciML/avik-pal-patch-1
fix: force a recent version of Lux to avoid ForwardDiff regression
2 parents 694c3fd + 4a72140 commit cc55450

File tree

6 files changed

+38
-44
lines changed

6 files changed

+38
-44
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,13 @@ steps:
2323
- label: "Documentation"
2424
plugins:
2525
- JuliaCI/julia#v1:
26-
version: "1.10"
26+
version: "1"
2727
command: |
2828
julia --project -e '
2929
println("--- :julia: Instantiating project")
3030
using Pkg
3131
Pkg.instantiate()
3232
Pkg.activate("docs")
33-
Pkg.develop(PackageSpec(path=pwd()))
3433
Pkg.instantiate()
3534
push!(LOAD_PATH, @__DIR__)
3635
println("+++ :julia: Building documentation")

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqFlux"
22
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "4.4.0"
4+
version = "4.4.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -30,7 +30,7 @@ DiffEqFluxDataInterpolationsExt = "DataInterpolations"
3030
ADTypes = "1.5"
3131
Aqua = "0.8.7"
3232
BenchmarkTools = "1.5.0"
33-
Boltz = "1"
33+
Boltz = "1.7"
3434
ChainRulesCore = "1"
3535
ComponentArrays = "0.15.17"
3636
ConcreteStructs = "0.2"
@@ -43,11 +43,11 @@ Distributions = "0.25"
4343
DistributionsAD = "0.6.55"
4444
ExplicitImports = "1.9"
4545
Flux = "0.16"
46-
ForwardDiff = "0.10"
46+
ForwardDiff = "0.10, 1"
4747
Hwloc = "3"
4848
InteractiveUtils = "<0.0.1, 1"
4949
LinearAlgebra = "1.10"
50-
Lux = "1"
50+
Lux = "1.22"
5151
LuxCUDA = "0.3.2"
5252
LuxCore = "1"
5353
LuxLib = "1.2"

docs/Project.toml

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
33
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
44
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
5-
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
65
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
76
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
87
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
98
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
109
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
11-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
10+
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
1211
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1312
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1413
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -27,27 +26,24 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2726
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2827
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2928
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
30-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3129
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3230
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
33-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3431
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3532
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
36-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3733
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3834

35+
[sources]
36+
DiffEqFlux = {path = ".."}
37+
3938
[compat]
40-
CSV = "0.10"
4139
CUDA = "5"
4240
ComponentArrays = "0.15"
43-
DataDeps = "0.7"
4441
DataFrames = "1"
4542
DiffEqFlux = "4"
4643
Distances = "0.10.7"
4744
Distributions = "0.25.78"
4845
Documenter = "1"
49-
Flux = "0.14, 0.15, 0.16"
50-
ForwardDiff = "0.10"
46+
ForwardDiff = "0.10, 1"
5147
IterTools = "1"
5248
LinearAlgebra = "1"
5349
Lux = "1"
@@ -65,11 +61,8 @@ OrdinaryDiffEq = "6.31"
6561
Plots = "1.36"
6662
Printf = "1"
6763
Random = "1"
68-
ReverseDiff = "1.14"
6964
SciMLBase = "2"
7065
SciMLSensitivity = "7.11"
71-
StableRNGs = "1"
7266
Statistics = "1"
7367
StochasticDiffEq = "6.56"
74-
Test = "1"
7568
Zygote = "0.6.62, 0.7"

docs/make.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ makedocs(; sitename = "DiffEqFlux.jl",
1313
authors = "Chris Rackauckas et al.",
1414
clean = true,
1515
doctest = false,
16-
linkcheck = true,
17-
warnonly = [:docs_block, :missing_docs],
16+
# linkcheck = true,
17+
warnonly = [:docs_block, :missing_docs, :linkcheck],
1818
modules = [DiffEqFlux],
1919
format = Documenter.HTML(; assets = ["assets/favicon.ico"],
2020
canonical = "https://docs.sciml.ai/DiffEqFlux/stable/"),

docs/src/examples/hamiltonian_nn.md

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,19 @@ dpdt = -2π_32 .* q_t
2626
data = cat(q_t, p_t; dims = 1)
2727
target = cat(dqdt, dpdt; dims = 1)
2828
B = 256
29-
NEPOCHS = 500
29+
NEPOCHS = 125
3030
dataloader = DataLoader((data, target); batchsize = B)
3131
32-
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
32+
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (32, 32, 1), gelu); autodiff = AutoZygote())
3333
ps, st = Lux.setup(Xoshiro(0), hnn)
34+
model = StatefulLuxLayer(hnn, ps, st)
3435
ps_c = ps |> ComponentArray
3536
36-
opt = OptimizationOptimisers.Adam(0.01f0)
37+
opt = OptimizationOptimisers.Adam(0.003f0)
3738
3839
function loss_function(ps, databatch)
3940
data, target = databatch
40-
pred, st_ = hnn(data, ps, st)
41+
pred = model(data, ps)
4142
return mean(abs2, pred .- target)
4243
end
4344
@@ -53,10 +54,10 @@ res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)
5354
5455
ps_trained = res.u
5556
56-
model = NeuralODE(
57+
nhde = NeuralODE(
5758
hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)
5859
59-
pred = Array(first(model(data[:, 1], ps_trained, st)))
60+
pred = Array(first(nhde(data[:, 1], ps_trained, st)))
6061
plot(data[1, :], data[2, :]; lw = 4, label = "Original")
6162
plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
6263
xlabel!("Position (q)")
@@ -69,7 +70,7 @@ ylabel!("Momentum (p)")
6970

7071
The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we generate the pairs ``(q, p)`` using the equations given at the top. Additionally, to supervise the training, we also generate the gradients. Next, we use Flux DataLoader for automatically batching our dataset.
7172

72-
```@example hamiltonian
73+
```julia
7374
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
7475
ComponentArrays, Optimization, OptimizationOptimisers, MLUtils
7576

@@ -83,25 +84,25 @@ dpdt = -2π_32 .* q_t
8384
data = cat(q_t, p_t; dims = 1)
8485
target = cat(dqdt, dpdt; dims = 1)
8586
B = 256
86-
NEPOCHS = 500
87+
NEPOCHS = 125
8788
dataloader = DataLoader((data, target); batchsize = B)
8889
```
8990

9091
### Training the HamiltonianNN
9192

9293
We parameterize the with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
9394

94-
```@example hamiltonian
95-
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
95+
```julia
96+
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (32, 32, 1), gelu); autodiff = AutoZygote())
9697
ps, st = Lux.setup(Xoshiro(0), hnn)
98+
model = StatefulLuxLayer(hnn, ps, st)
9799
ps_c = ps |> ComponentArray
98-
hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)
99100

100-
opt = OptimizationOptimisers.Adam(0.005f0)
101+
opt = OptimizationOptimisers.Adam(0.003f0)
101102

102103
function loss_function(ps, databatch)
103-
(data, target) = databatch
104-
pred = hnn_stateful(data, ps)
104+
data, target = databatch
105+
pred = model(data, ps)
105106
return mean(abs2, pred .- target)
106107
end
107108

@@ -110,7 +111,7 @@ function callback(state, loss)
110111
return false
111112
end
112113

113-
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
114+
opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff())
114115
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)
115116

116117
res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)
@@ -123,11 +124,11 @@ ps_trained = res.u
123124
In order to visualize the learned trajectories, we need to solve the ODE. We will use the
124125
`NeuralODE` layer with `HamiltonianNN` layer, and solves the ODE.
125126

126-
```@example hamiltonian
127-
model = NeuralODE(
127+
```julia
128+
nhde = NeuralODE(
128129
hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)
129130

130-
pred = Array(first(model(data[:, 1], ps_trained, st)))
131+
pred = Array(first(nhde(data[:, 1], ps_trained, st)))
131132
plot(data[1, :], data[2, :]; lw = 4, label = "Original")
132133
plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
133134
xlabel!("Position (q)")

docs/src/examples/neural_ode_weather_forecast.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@ The data is a four-dimensional dataset of daily temperature, humidity, wind spee
99

1010
```@example weather_forecast
1111
using Random, Dates, Optimization, ComponentArrays, Lux, OptimizationOptimisers, DiffEqFlux,
12-
OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots, DataDeps
12+
OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots
13+
using Downloads: download
1314
1415
function download_data(
1516
data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/",
1617
data_local_path = "./delhi")
1718
function load(file_name)
18-
data_dep = DataDep("delhi/train", "", "$data_url/$file_name")
19-
Base.download(data_dep, data_local_path; i_accept_the_terms_of_use = true)
20-
CSV.read(joinpath(data_local_path, file_name), DataFrame)
19+
download("$data_url/$file_name", joinpath(data_local_path, file_name))
20+
return CSV.read(joinpath(data_local_path, file_name), DataFrame)
2121
end
2222
23+
mkpath(data_local_path)
2324
train_df = load("DailyDelhiClimateTrain.csv")
2425
test_df = load("DailyDelhiClimateTest.csv")
2526
return vcat(train_df, test_df)
@@ -102,7 +103,7 @@ We are now ready to construct and train our model! To avoid local minimas we wil
102103
function neural_ode(t, data_dim)
103104
f = Chain(Dense(data_dim => 64, swish), Dense(64 => 32, swish), Dense(32 => data_dim))
104105
105-
node = NeuralODE(f, extrema(t), Tsit5(); saveat = t, abstol = 1e-9, reltol = 1e-9)
106+
node = NeuralODE(f, extrema(t), Tsit5(); saveat = t, abstol = 1e-6, reltol = 1e-3)
106107
107108
rng = Xoshiro(0)
108109
p, state = Lux.setup(rng, f)
@@ -151,7 +152,7 @@ ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng; progres
151152
We can now animate the training to get a better understanding of the fit.
152153

153154
```@example weather_forecast
154-
predict(y0, t, p, state) = begin
155+
function predict(y0, t, p, state)
155156
node, _, _ = neural_ode(t, length(y0))
156157
Array(node(y0, p, state)[1])
157158
end

0 commit comments

Comments
 (0)