This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a8073de

Merge pull request #67 from ven-k/vk/nomad
Add Nonlinear Manifold Decoders for Operator Learning (NOMAD)
2 parents 16d641f + d093556 commit a8073de

File tree: 7 files changed (+163 −1 lines)

docs/src/introduction.md

Lines changed: 5 additions & 0 deletions
@@ -24,3 +24,8 @@ by linking the operators into a Markov chain.
 Deep operator network (DeepONet) learns a neural operator with the help of two sub-neural-network structures, described as the branch and the trunk network.
 The branch network is fed the initial-conditions data, whereas the trunk is fed the locations where the target (output) is evaluated from the corresponding initial conditions.
 It is important that the output sizes of the branch and trunk subnets are the same, so that a dot product can be performed between them.
+
+## [Nonlinear Manifold Decoders for Operator Learning](https://github.com/SciML/NeuralOperators.jl/blob/master/src/NOMAD.jl)
+
+Nonlinear Manifold Decoders for Operator Learning (NOMAD) learns a neural operator with a nonlinear decoder, parameterized by a deep neural network, that jointly takes the approximator output and the locations as inputs.
+The approximator network is fed the initial-conditions data. The approximator output and the locations are then passed to a decoder neural network to produce the target (output). It is important that the input size of the decoder subnet equals the sum of the approximator output size and the number of locations, as the sketch below illustrates.
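A minimal sketch of this sizing rule, using the `NOMAD` constructor added in this PR (shapes mirror `test/nomad.jl`; the names `u0` and `locations` are illustrative): with 16 approximator outputs and 16 locations, the decoder input must be 32.

```julia
using NeuralOperators, Flux

u0 = rand(Float32, 16)                        # initial-conditions data (16 sensors)
locations = collect(range(0, 1, length=16)')  # 1×16 row of evaluation locations

# Approximator maps 16 -> 22 -> 16; decoder input = 16 + 16 = 32.
model = NOMAD((16, 22, 16), (32, 16), σ, tanh)

size(model(u0, locations))  # (1, 16): the target evaluated at the 16 locations
```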

example/Burgers/src/Burgers.jl

Lines changed: 39 additions & 0 deletions
@@ -3,6 +3,7 @@ module Burgers
 using DataDeps, MAT, MLUtils
 using NeuralOperators, Flux
 using CUDA, FluxTraining, BSON
+import Flux: params
 
 include("Burgers_deeponet.jl")

@@ -69,4 +70,42 @@ function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=500)
     return learner
 end
 
+function train_nomad(; n=300, cuda=true, learning_rate=0.001, epochs=400)
+    if cuda && has_cuda()
+        @info "Training on GPU"
+        device = gpu
+    else
+        @info "Training on CPU"
+        device = cpu
+    end
+
+    x, y = get_data_don(n=n)
+
+    xtrain = x[1:280, :]'
+    ytrain = y[1:280, :]
+
+    xval = x[end-19:end, :]' |> device
+    yval = y[end-19:end, :] |> device
+
+    # grid = collect(range(0, 1, length=1024)') |> device
+    grid = rand(collect(0:0.001:1), (280, 1024)) |> device
+    gridval = rand(collect(0:0.001:1), (20, 1024)) |> device
+
+    opt = ADAM(learning_rate)
+
+    # Decoder input is 2048: 1024 approximator outputs + 1024 sensor locations.
+    m = NOMAD((1024, 1024), (2048, 1024), gelu, gelu) |> device
+
+    loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)
+    evalcb() = @show(loss(xval, yval, gridval))
+
+    data = [(xtrain, ytrain, grid)] |> device
+    Flux.@epochs epochs Flux.train!(loss, params(m), data, opt, cb=evalcb)
+    ỹ = m(xval |> device, gridval |> device)
+
+    # Mean absolute error over the validation set
+    diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))
+    mean_diff = sum(diffvec) / length(diffvec)
+    return mean_diff
+end
 end
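As a rough guide to running the new trainer (a sketch only, not part of the diff): assuming the example project's environment is active and the Burgers dataset has been fetched via DataDeps, a quick smoke run might look like the following; `cuda=false, epochs=10` are illustrative values, not the defaults.

```julia
using Burgers

# Quick CPU smoke run; returns the mean absolute validation error.
err = Burgers.train_nomad(; cuda=false, epochs=10)
@show err
```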

example/Burgers/test/runtests.jl

Lines changed: 5 additions & 0 deletions
@@ -13,3 +13,8 @@ using Test
 
 # include("deeponet.jl")
 end
+
+@testset "Burger: NOMAD Training Accuracy" begin
+    ϵ = Burgers.train_nomad(; cuda=true, epochs=100)
+    @test ϵ < 0.4 # epochs=100 returns ≈0.233
+end

src/NOMAD.jl

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+struct NOMAD{T1, T2}
+    approximator_net::T1
+    decoder_net::T2
+end
+
+"""
+`NOMAD(architecture_approximator::Tuple, architecture_decoder::Tuple,
+       act_approximator = identity, act_decoder = identity;
+       init_approximator = Flux.glorot_uniform,
+       init_decoder = Flux.glorot_uniform,
+       bias_approximator = true, bias_decoder = true)`
+`NOMAD(approximator_net::Flux.Chain, decoder_net::Flux.Chain)`
+
+Create a Nonlinear Manifold Decoders for Operator Learning (NOMAD) model, as
+proposed by Seidman et al. in arXiv:2206.03551.
+
+The decoder is defined as follows:
+
+``\\tilde D (β, y) = f(β, y)``
+
+# Usage
+
+```julia
+julia> model = NOMAD((16,32,16), (24,32))
+NOMAD with
+Approximator net: (Chain(Dense(16 => 32), Dense(32 => 16)))
+Decoder net: (Chain(Dense(24 => 32)))
+
+julia> model = NeuralOperators.NOMAD((32,64,32), (64,72), σ, tanh; init_approximator=Flux.glorot_normal, bias_decoder=false)
+NOMAD with
+Approximator net: (Chain(Dense(32 => 64, σ), Dense(64 => 32, σ)))
+Decoder net: (Chain(Dense(64 => 72, tanh; bias=false)))
+
+julia> approximator = Chain(Dense(2, 128), Dense(128, 64))
+Chain(
+  Dense(2 => 128),   # 384 parameters
+  Dense(128 => 64),  # 8_256 parameters
+)                    # Total: 4 arrays, 8_640 parameters, 34.000 KiB.
+
+julia> decoder = Chain(Dense(72, 24), Dense(24, 12))
+Chain(
+  Dense(72 => 24),   # 1_752 parameters
+  Dense(24 => 12),   # 300 parameters
+)                    # Total: 4 arrays, 2_052 parameters, 8.266 KiB.
+
+julia> model = NOMAD(approximator, decoder)
+NOMAD with
+Approximator net: (Chain(Dense(2 => 128), Dense(128 => 64)))
+Decoder net: (Chain(Dense(72 => 24), Dense(24 => 12)))
+```
+"""
+function NOMAD(architecture_approximator::Tuple, architecture_decoder::Tuple,
+               act_approximator = identity, act_decoder = identity;
+               init_approximator = Flux.glorot_uniform,
+               init_decoder = Flux.glorot_uniform,
+               bias_approximator = true, bias_decoder = true)
+    approximator_net = construct_subnet(architecture_approximator, act_approximator;
+                                        init=init_approximator, bias=bias_approximator)
+
+    decoder_net = construct_subnet(architecture_decoder, act_decoder;
+                                   init=init_decoder, bias=bias_decoder)
+
+    return NOMAD{typeof(approximator_net), typeof(decoder_net)}(approximator_net, decoder_net)
+end
+
+Flux.@functor NOMAD
+
+function (a::NOMAD)(x::AbstractArray, y::AbstractVecOrMat)
+    # Assign the subnetworks
+    approximator, decoder = a.approximator_net, a.decoder_net
+
+    # Concatenate the approximator output with the (transposed) locations
+    # along the feature dimension, then decode.
+    return decoder(cat(approximator(x), y', dims=1))'
+end
+
+# Print nicely
+function Base.show(io::IO, l::NOMAD)
+    print(io, "NOMAD with\nApproximator net: (", l.approximator_net)
+    print(io, ")\n")
+    print(io, "Decoder net: (", l.decoder_net)
+    print(io, ")\n")
+end
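The forward pass above fixes a shape contract worth spelling out: `x` is (approximator input size × batch) and `y` is (batch × number of locations); the result comes back as (batch × decoder output size). A small sketch under those assumptions, using the same architecture as the Burgers trainer:

```julia
using NeuralOperators, Flux

m = NOMAD((1024, 1024), (2048, 1024), gelu, gelu)

x = rand(Float32, 1024, 8)  # 8 samples, one column each
y = rand(Float32, 8, 1024)  # 8 rows of 1024 evaluation locations

# approximator(x) is 1024×8; y' is 1024×8; cat along dims=1 gives 2048×8;
# the decoder maps this to 1024×8 and the final transpose yields 8×1024.
size(m(x, y))  # (8, 1024)
```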

src/NeuralOperators.jl

Lines changed: 2 additions & 1 deletion
@@ -10,12 +10,13 @@ module NeuralOperators
 using GeometricFlux
 using Statistics
 
-export DeepONet
+export DeepONet, NOMAD
 
 include("Transform/Transform.jl")
 include("operator_kernel.jl")
 include("loss.jl")
 include("model.jl")
 include("DeepONet.jl")
+include("NOMAD.jl")
 include("subnets.jl")
 end

test/nomad.jl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+@testset "NOMAD" begin
+    @testset "proper construction" begin
+        nomad = NOMAD((32,64,72), (24,48,72), σ, tanh)
+        # approximator net
+        @test size(nomad.approximator_net.layers[end].weight) == (72, 64)
+        @test size(nomad.approximator_net.layers[end].bias) == (72,)
+        # decoder net
+        @test size(nomad.decoder_net.layers[end].weight) == (72, 48)
+        @test size(nomad.decoder_net.layers[end].bias) == (72,)
+    end
+
+    # Accept only Int as architecture parameters
+    @test_throws MethodError NOMAD((32.5,64,72), (24,48,72), σ, tanh)
+    @test_throws MethodError NOMAD((32,64,72), (24.1,48,72))
+
+    # Just the first 16 datapoints from the Burgers' equation dataset
+    a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755,
+         0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651,
+         0.81943734, 0.81737952, 0.8152405, 0.81302771]
+    sensors = collect(range(0, 1, length=16)')
+    model = NOMAD((length(a), 22, length(a)), (length(a) + length(sensors), length(sensors)), σ, tanh; init_approximator=Flux.glorot_normal, bias_decoder=false)
+    y = model(a, sensors)
+    @test size(y) == (1, 16)
+    # Check that the model description is printed as defined
+    @test repr(model) == "NOMAD with\nApproximator net: (Chain(Dense(16 => 22, σ), Dense(22 => 16, σ)))\nDecoder net: (Chain(Dense(32 => 16, tanh; bias=false)))\n"
+
+    mgrad = Flux.Zygote.gradient(() -> sum(model(a, sensors)), Flux.params(model))
+    @info mgrad.grads
+    @test length(mgrad.grads) == 5
+end
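The final assertion counts one gradient entry per trainable array: the two approximator `Dense` layers contribute a weight and a bias each (4 arrays), and the bias-free decoder `Dense` contributes only its weight, giving 5. A quick check along those lines (a sketch, assuming `Flux.params` collects exactly those arrays):

```julia
using NeuralOperators, Flux

model = NOMAD((16, 22, 16), (32, 16), σ, tanh; bias_decoder=false)
length(Flux.params(model))  # 5 = 2 × (weight + bias) + 1 decoder weight
```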

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ tests = [
     "loss.jl",
     "model.jl",
     "deeponet.jl",
+    "nomad.jl",
 ]
 
 if CUDA.functional()
