Commit ad8c34a

✍️ expand docstring
1 parent 42a7147 commit ad8c34a

2 files changed: +52 additions, −9 deletions


src/DeepONet.jl

Lines changed: 48 additions & 7 deletions
@@ -1,8 +1,12 @@
 """
-`DeepONet(in, out, grid, modes, σ=identity, init=glorot_uniform)`
-`DeepONet(Wf::AbstractArray, Wl::AbstractArray, [bias_f, bias_l, σ])`
+`DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
+          act_branch = identity, act_trunk = identity;
+          init_branch = Flux.glorot_uniform,
+          init_trunk = Flux.glorot_uniform,
+          bias_branch=true, bias_trunk=true)`
+`DeepONet(branch_net::Flux.Chain, trunk_net::Flux.Chain)`
 
-Create a DeepONet architecture as proposed by Lu et al.
+Create an (unstacked) DeepONet architecture as proposed by Lu et al.
 arXiv:1910.03193
 
 The model works as follows:
@@ -19,8 +23,42 @@ and `y` are the probing locations for the operator to be trained.
 
 Both inputs `x` and `y` are multiplied together via dot product Σᵢ bᵢⱼ tᵢₖ.
 
+You can set up this architecture in two ways:
+
+1. By specifying the architecture and all its parameters as given above. This always creates `Dense` layers for the branch and trunk net and corresponds to the DeepONet proposed by Lu et al.
+
+2. By passing two architectures in the form of two `Chain` structs directly. Do this if you want more flexibility, e.g. to use an RNN or CNN instead of simple `Dense` layers.
+
+Strictly speaking, DeepONet does not require either the branch or the trunk net to be a simple DNN. Usually, though, this is the case, which is why it is treated as the default here.
+
 ```julia
-model = DeepONet()
+julia> model = DeepONet((32,64,72), (24,64,72))
+DeepONet with
+branch net: (Chain(Dense(32, 64), Dense(64, 72)))
+Trunk net: (Chain(Dense(24, 64), Dense(64, 72)))
+
+julia> model = DeepONet((32,64,72), (24,64,72), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
+DeepONet with
+branch net: (Chain(Dense(32, 64, σ), Dense(64, 72, σ)))
+Trunk net: (Chain(Dense(24, 64, tanh; bias=false), Dense(64, 72, tanh; bias=false)))
+
+julia> branch = Chain(Dense(2,128), Dense(128,64), Dense(64,72))
+Chain(
+  Dense(2, 128),   # 384 parameters
+  Dense(128, 64),  # 8_256 parameters
+  Dense(64, 72),   # 4_680 parameters
+)                  # Total: 6 arrays, 13_320 parameters, 52.406 KiB.
+
+julia> trunk = Chain(Dense(1,24), Dense(24,72))
+Chain(
+  Dense(1, 24),    # 48 parameters
+  Dense(24, 72),   # 1_800 parameters
+)                  # Total: 4 arrays, 1_848 parameters, 7.469 KiB.
+
+julia> model = DeepONet(branch, trunk)
+DeepONet with
+branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
+Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
 ```
 """
 struct DeepONet
@@ -37,14 +75,17 @@ end
 
 # Declare the function that assigns weights and biases to the layer
 function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                   act_branch = identity, act_trunk = identity;
-                  init = Flux.glorot_uniform,
+                  init_branch = Flux.glorot_uniform,
+                  init_trunk = Flux.glorot_uniform,
                   bias_branch=true, bias_trunk=true)
 
     # To construct the subnets we use the helper function in subnets.jl
     # Initialize the branch net
-    branch_net = construct_subnet(architecture_branch, act_branch; bias=bias_branch)
+    branch_net = construct_subnet(architecture_branch, act_branch;
+                                  init=init_branch, bias=bias_branch)
     # Initialize the trunk net
-    trunk_net = construct_subnet(architecture_trunk, act_trunk; bias=bias_trunk)
+    trunk_net = construct_subnet(architecture_trunk, act_trunk;
+                                 init=init_trunk, bias=bias_trunk)
 
     return DeepONet(branch_net, trunk_net)
 end
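For context, here is a minimal sketch of what the split keywords enable. The constructor calls mirror the docstring examples above; the last three lines only illustrate the Σᵢ bᵢⱼ tᵢₖ dot-product combination described in the docstring and assume the struct exposes `branch_net`/`trunk_net` fields — they are not the package's actual forward pass:

```julia
using Flux  # assumes the package defining DeepONet is also loaded

# Branch and trunk can now be initialized independently.
model = DeepONet((32, 64, 72), (24, 64, 72), σ, tanh;
                 init_branch = Flux.glorot_normal,
                 init_trunk  = Flux.glorot_uniform)

# Illustration of Σᵢ bᵢⱼ tᵢₖ: both subnets map into the same
# 72-dimensional latent space, where their outputs are contracted.
b = model.branch_net(rand(Float32, 32, 10))   # 72×10  (10 input functions)
t = model.trunk_net(rand(Float32, 24, 100))   # 72×100 (100 probing locations)
out = b' * t                                  # 10×100; out[j, k] = Σᵢ b[i, j] * t[i, k]
```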

src/subnets.jl

Lines changed: 4 additions & 2 deletions
@@ -23,12 +23,14 @@ julia> model([2,1])
 -0.7630446
 ```
 """
-function construct_subnet(architecture::Tuple, σ = identity; bias=true)
+function construct_subnet(architecture::Tuple, σ = identity;
+                          init=Flux.glorot_uniform, bias=true)
     # First, create an array that contains all Dense layers independently
     # Given an n-element architecture, this constructs n-1 layers
     layers = Array{Flux.Dense}(undef, length(architecture)-1)
     @inbounds for i ∈ 2:length(architecture)
-        layers[i-1] = Flux.Dense(architecture[i-1], architecture[i], σ; bias=bias)
+        layers[i-1] = Flux.Dense(architecture[i-1], architecture[i], σ;
+                                 init=init, bias=bias)
     end
 
     # Concatenate the layers to a string, chain them and parse them into
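Likewise, the new `init` keyword on the helper can be exercised on its own. A minimal sketch, assuming `construct_subnet` is in scope (e.g. via the package's internals); `Flux.zeros32` is chosen here only to make the effect of `init` directly observable:

```julia
using Flux

# With zero-initialized weights and no bias, every output must be tanh(0) = 0,
# confirming that the `init` keyword reaches the underlying Dense layers.
subnet = construct_subnet((2, 16, 1), tanh; init = Flux.zeros32, bias = false)
subnet(Float32[0.5, -0.5])   # returns Float32[0.0]
```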
