Commit 0d9fe6b

πŸ§‘β€πŸ”§ fix dot product, expand docstring
1 parent 31fc2e1 commit 0d9fe6b

1 file changed (+21, -27 lines)


β€Žsrc/DeepONet.jlβ€Ž

Lines changed: 21 additions & 27 deletions
@@ -17,8 +17,8 @@ x --- branch --
                |
 y --- trunk ---
 
-Where `x` represent the parameters of the PDE, discretely evaluated at its respective sensors,
-and `y` are the probing locations for the operator to be trained.
+Where `x` represents the input function, discretely evaluated at its respective sensors. So the input is of shape [m] for one instance, or [m x b] for a training set.
+`y` are the probing locations for the operator to be trained. It has shape [N x n] for N different variables in the PDE (i.e. spatial and temporal coordinates), with n distinct evaluation points each.
 
 `u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.
 
 Both inputs `x` and `y` are multiplied together via dot product Σᵢ bᵢⱼ tᵢₖ.
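The contraction Σᵢ bᵢⱼ tᵢₖ can be sanity-checked numerically; here is a minimal NumPy sketch (all shapes hypothetical: p latent nodes shared by both nets, b instances, n query points):

```python
import numpy as np

p, b, n = 16, 8, 100      # shared last-layer width, batch size, query points
B = np.random.rand(p, b)  # branch output b_ij: one latent column per input function
T = np.random.rand(p, n)  # trunk output t_ik: one latent column per query point

# Sum over the shared latent index i, as in the dot product above
out = np.einsum("ij,ik->jk", B, T)
print(out.shape)          # (8, 100): one row of solution values per instance
```

In Julia terms this is just `B' * T`, which is why both nets must end in the same number of nodes.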
@@ -31,6 +31,15 @@ You can set up this architecture in two ways:
 
 Strictly speaking, DeepONet does not imply either of the branch or trunk net to be a simple DNN. Usually though, this is the case, which is why it's treated as the default case here.
 
+# Example
+
+Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
+We are given several (b = 200) instances of the IC, discretized at 50 points each, and want to query the solution at 100 different locations and times in [0;1].
+
+That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100].
+
+# Usage
+
 ```julia
 julia> model = DeepONet((32,64,72), (24,64,72))
 DeepONet with
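The shape bookkeeping of the advection example can be traced with a NumPy sketch, in which two fixed random matrices stand in for the branch and trunk nets and p = 72 is an assumed shared last-layer width:

```python
import numpy as np

m, b, N, n, p = 50, 200, 2, 100, 72
x = np.random.rand(m, b)   # 200 ICs, each discretized at 50 sensor points
y = np.random.rand(N, n)   # 100 query points in (x, t)

Wb = np.random.rand(p, m)  # stand-in for the branch net
Wt = np.random.rand(p, N)  # stand-in for the trunk net

u = (Wb @ x).T @ (Wt @ y)  # branch(x)' * trunk(y)
print(u.shape)             # (200, 100): one solution value per IC and query point
```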
@@ -73,9 +82,7 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                  init_trunk = Flux.glorot_uniform,
                  bias_branch=true, bias_trunk=true)
 
-    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk
-    net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ
-    won't work."
+    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
 
     # To construct the subnets we use the helper function in subnets.jl
     # Initialize the branch net
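The `@assert` guards the contraction dimension; a hedged NumPy illustration of the failure mode it rules out (the widths 72 and 64 are made up):

```python
import numpy as np

B = np.random.rand(72, 8)    # branch last layer: 72 nodes
T = np.random.rand(64, 100)  # trunk last layer: 64 nodes (deliberate mismatch)

try:
    B.T @ T                  # no shared latent index i to contract over
except ValueError as err:
    print("shape mismatch:", err)
```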
@@ -90,34 +97,21 @@ end
 
 Flux.@functor DeepONet
 
-# The actual layer that does stuff
-# x needs to be at least a 2-dim array,
-# since we need n inputs, evaluated at m locations
-function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
+#= The actual layer that does stuff
+x is the input function, evaluated at m locations (or m x b in case of batches)
+y is the array of sensors, i.e. the variables of the output function,
+with shape (N x n): N different variables, with n evaluation points each =#
+function (a::DeepONet)(x::AbstractVecOrMat, y::AbstractVecOrMat)
     # Assign the parameters
     branch, trunk = a.branch_net, a.trunk_net
 
-    # Dot product needs a dim to contract
-    # However, inputs are normally given with batching done in the same dim
-    # so we need to adjust (i.e. transpose) one of the inputs,
-    # and that's easiest on the matrix-type input
+    #= Dot product needs a dim to contract
+    However, we perform the transformations by the NNs always in the first dim,
+    so we need to adjust (i.e. transpose) one of the inputs,
+    which we do on the branch input here =#
     return branch(x)' * trunk(y)
 end
 
-# Handling batches:
-# We use basically the same function, but using NNlib's batched_mul instead of
-# regular matrix-matrix multiplication
-function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
-    # Assign the parameters
-    branch, trunk = a.branch_net, a.trunk_net
-
-    # Dot product needs a dim to contract
-    # However, inputs are normally given with batching done in the same dim
-    # so we need to adjust (i.e. transpose) one of the inputs,
-    # and that's easiest on the matrix-type input
-    return branch(x) ⊠ trunk(y)'
-end
-
 # Sensors stay the same and shouldn't be batched
 (a::DeepONet)(x::AbstractArray, y::AbstractArray) =
     throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))
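Widening the signature to `AbstractVecOrMat` works because `branch(x)' * trunk(y)` already covers both the single-instance and the batched case, so the separate `batched_mul` method could be dropped. A NumPy sketch with linear stand-ins for the two nets (all names and shapes hypothetical):

```python
import numpy as np

p = 16                      # assumed shared last-layer width
Wb = np.random.rand(p, 50)  # stand-in branch net: 50 sensors -> p
Wt = np.random.rand(p, 2)   # stand-in trunk net: 2 coords -> p
branch = lambda x: Wb @ x
trunk = lambda y: Wt @ y

y = np.random.rand(2, 100)  # query locations, never batched

x1 = np.random.rand(50)     # a single input function
print((branch(x1).T @ trunk(y)).shape)  # (100,)

xb = np.random.rand(50, 200)            # a batch of 200 input functions
print((branch(xb).T @ trunk(y)).shape)  # (200, 100)
```

The transpose on the branch side is what lets one expression serve both cases.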
