
Commit 4fa134c

🆕 initial DeepONet implementation
1 parent ac0fd84 commit 4fa134c

File tree

3 files changed: +132 -1 lines changed

src/DeepONet.jl

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
"""
`DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple, act_branch=identity, act_trunk=identity; init=glorot_uniform, bias_branch=true, bias_trunk=true)`

`DeepONet(branch_net::Flux.Chain, trunk_net::Flux.Chain)`

Create a DeepONet architecture as proposed by Lu et al., arXiv:1910.03193.

The model works as follows:

x --- branch --
               |
               -⊠--u-
               |
y --- trunk ---

where `x` represents the parameters of the PDE, discretely evaluated at its respective sensors,
and `y` are the probing locations for the operator to be trained.
`u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.

The outputs of the two subnets are combined via the dot product Σᵢ bᵢⱼ tᵢₖ,
contracting over the shared latent dimension `i`.

```julia
model = DeepONet((32,64,72), (24,64,72))
```
"""
struct DeepONet
    branch_net::Flux.Chain
    trunk_net::Flux.Chain
    # Inner constructor for the DeepONet
    function DeepONet(
        branch_net::Flux.Chain,
        trunk_net::Flux.Chain)
        new(branch_net, trunk_net)
    end
end

# Convenience constructor: builds both subnets from tuples of layer widths
function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                  act_branch = identity, act_trunk = identity;
                  init = Flux.glorot_uniform,
                  bias_branch = true, bias_trunk = true)

    # To construct the subnets we use the helper function in subnets.jl
    # Initialize the branch net
    branch_net = construct_subnet(architecture_branch, act_branch;
                                  init = init, bias = bias_branch)
    # Initialize the trunk net
    trunk_net = construct_subnet(architecture_trunk, act_trunk;
                                 init = init, bias = bias_trunk)

    return DeepONet(branch_net, trunk_net)
end

Flux.@functor DeepONet

# The actual layer that does stuff
# x needs to be at least a 2-dim array,
# since we need n inputs, evaluated at m locations
function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
    # Assign the parameters
    branch, trunk = a.branch_net, a.trunk_net

    # The dot product Σᵢ bᵢⱼ tᵢₖ contracts the shared latent (first) dimension
    # of the two subnet outputs, so we transpose the branch output:
    # (batch × latent) * (latent × locations) -> (batch × locations)
    return branch(x)' * trunk(y)
end

# Handling batches:
# We use basically the same function, but with NNlib's batched_mul instead of
# regular matrix-matrix multiplication
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
    # Assign the parameters
    branch, trunk = a.branch_net, a.trunk_net

    # Same contraction as above, batched over the trailing dimension of x;
    # the trunk output is a plain matrix and is reused for every batch slice
    return permutedims(branch(x), (2, 1, 3)) ⊠ trunk(y)
end

# Sensors stay the same and shouldn't be batched
(a::DeepONet)(x::AbstractArray, y::AbstractArray) =
    throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))

# Print nicely
function Base.show(io::IO, l::DeepONet)
    print(io, "DeepONet with\nbranch net: (", l.branch_net)
    print(io, ")\n")
    print(io, "Trunk net: (", l.trunk_net)
    print(io, ")\n")
end
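
To make the shape convention concrete, here is a minimal usage sketch (not part of the commit): the layer widths, activations, and batch sizes are arbitrary illustration values, assuming only that branch and trunk end in the same latent width so the contraction is defined.

```julia
using Flux, OperatorLearning

# Branch reads 16 sensor values per PDE instance; both subnets end in a
# latent width of 72 so that branch(x)' * trunk(y) contracts cleanly
model = DeepONet((16, 64, 72), (1, 64, 72), identity, tanh)

x = rand(Float32, 16, 8)   # 8 PDE instances, each sampled at 16 sensors
y = rand(Float32, 1, 100)  # 100 probing locations in 1D

u = model(x, y)            # 8 × 100: one row of solution values per instance
```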

src/OperatorLearning.jl

Lines changed: 3 additions & 1 deletion
@@ -10,10 +10,12 @@ using Random: AbstractRNG
 using Flux: nfan, glorot_uniform, batch
 using OMEinsum

-export FourierLayer
+export FourierLayer, DeepONet

 include("FourierLayer.jl")
+include("DeepONet.jl")
 include("ComplexWeights.jl")
 include("batched.jl")
+include("subnets.jl")

 end # module
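
With `DeepONet` added to the export line, downstream code can use the new type without qualification; a quick sketch, reusing the tuple constructor from the docstring above:

```julia
using OperatorLearning

# Resolves directly thanks to the updated export
model = DeepONet((32, 64, 72), (24, 64, 72))
```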

src/subnets.jl

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
"""
Construct a Chain of `Dense` layers from a given tuple of integers.

Input:
A tuple `(m, n, o, p)` of integers, where each element gives the width of one layer.

Output:
A `Flux` Chain of `Dense` layers, one fewer than the length of the input tuple,
with the individual widths given by consecutive tuple elements.

# Example

```julia
julia> model = OperatorLearning.construct_subnet((2,128,64,32,1))
Chain(
  Dense(2, 128),   # 384 parameters
  Dense(128, 64),  # 8_256 parameters
  Dense(64, 32),   # 2_080 parameters
  Dense(32, 1),    # 33 parameters
)                  # Total: 8 arrays, 10_753 parameters, 42.504 KiB.

julia> model([2,1])
1-element Vector{Float32}:
 -0.7630446
```
"""
function construct_subnet(architecture::Tuple, σ = identity;
                          init = Flux.glorot_uniform, bias = true)
    # First, create an array that contains all Dense layers independently
    # Given an n-element architecture, this constructs n-1 layers
    layers = Array{Flux.Dense}(undef, length(architecture)-1)
    @inbounds for i ∈ 2:length(architecture)
        layers[i-1] = Flux.Dense(architecture[i-1], architecture[i], σ;
                                 init = init, bias = bias)
    end

    # Splat the layer array directly into the Flux Chain constructor
    return Flux.Chain(layers...)
end
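
As a sanity check on the helper (again not part of the commit), the tuple form should behave like the equivalent hand-written `Chain`; the widths below are made up, and `construct_subnet` is qualified because it is not exported.

```julia
using Flux, OperatorLearning

net = OperatorLearning.construct_subnet((2, 16, 1), tanh)
ref = Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 1, tanh))

# Same structure, independently initialized weights; both map
# a (2 × batch) input to a (1 × batch) output
size(net(rand(Float32, 2, 5))) == size(ref(rand(Float32, 2, 5)))  # true
```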
