src/DeepONet.jl: 21 additions & 27 deletions
@@ -17,8 +17,8 @@ x --- branch --
 |
 y --- trunk ---

-Where `x` represent the parameters of the PDE, discretely evaluated at its respective sensors,
-and `y` are the probing locations for the operator to be trained.
+Where `x` represents the input function, discretely evaluated at its respective sensors. The input is thus of shape [m] for one instance or [m x b] for a training set.
+`y` are the probing locations for the operator to be trained. It has shape [N x n] for N different variables in the PDE (i.e. spatial and temporal coordinates), with n distinct evaluation points each.
 `u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.

 Both inputs `x` and `y` are multiplied together via dot product Σᵢ bᵢⱼ tᵢₖ.
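The contraction Σᵢ bᵢⱼ tᵢₖ can be sketched in plain Python (the helper `combine` is illustrative, not part of the package): branch features b[i][j] and trunk features t[i][k] share the latent index i, giving one output row per input instance j and one column per query point k.

```python
# Sketch of the DeepONet output contraction Σᵢ bᵢⱼ tᵢₖ (illustrative only):
# branch features b are [p x batch], trunk features t are [p x n],
# and the result is a [batch x n] array of operator evaluations.

def combine(b, t):
    """Contract the shared latent dim i of b[i][j] and t[i][k]."""
    p = len(b)          # latent features in the last layer
    batch = len(b[0])   # input-function instances
    n = len(t[0])       # query points
    return [[sum(b[i][j] * t[i][k] for i in range(p)) for k in range(n)]
            for j in range(batch)]

# p = 2 latent features, batch = 2 input functions, n = 3 query points
b = [[1.0, 2.0],
     [3.0, 4.0]]
t = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]
out = combine(b, t)  # shape [2 x 3]: one row per instance, one column per query
```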
@@ -31,6 +31,15 @@ You can set up this architecture in two ways:

 Strictly speaking, DeepONet does not imply either of the branch or trunk net to be a simple DNN. Usually though, this is the case which is why it's treated as the default case here.

+# Example
+
+Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
+We are given several (b = 200) instances of the IC, discretized at 50 points each, and want to query the solution at 100 different locations and times in [0;1].
+
+That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100].
+
+# Usage
+
 ```julia
 julia> model = DeepONet((32,64,72), (24,64,72))
 DeepONet with
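For the advection example added in this diff, the shape bookkeeping can be sketched as follows (variable names are illustrative; the latent width p = 72 is taken from the docstring's `DeepONet((32,64,72), (24,64,72))` call):

```python
# Shape bookkeeping for the advection example (a sketch, not library code).
m, b = 50, 200   # sensors per IC instance, number of IC instances
N, n = 2, 100    # coordinates (space and time), query points
p = 72           # shared width of the last branch/trunk layer

branch_out = (p, b)   # branch net maps [m x b] -> [p x b]
trunk_out = (p, n)    # trunk net maps  [N x n] -> [p x n]

# branch(x)' * trunk(y) contracts over p, leaving one solution row per
# instance and one column per query point
output_shape = (branch_out[1], trunk_out[1])
print(output_shape)  # (200, 100)
```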
@@ -73,9 +82,7 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                   init_trunk = Flux.glorot_uniform,
                   bias_branch=true, bias_trunk=true)

-    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk
-    net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ
-    won't work."
+    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk net must share the same number of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."

     # To construct the subnets we use the helper function in subnets.jl
     # Initialize the branch net
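The merged `@assert` guards the shared last-layer width that the contraction Σᵢ bᵢⱼ tᵢₖ relies on. A hypothetical Python mirror of that check (not the package API) makes the rule explicit:

```python
# Sketch of the constructor's shape guard: the last branch and trunk layer
# widths must match, since the output contracts over that shared dimension.
def check_architectures(architecture_branch, architecture_trunk):
    if architecture_branch[-1] != architecture_trunk[-1]:
        raise ValueError("Branch and Trunk net must share the same number "
                         "of nodes in the last layer.")
    return True

check_architectures((32, 64, 72), (24, 64, 72))  # OK: both end in 72
```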
@@ -90,34 +97,21 @@ end

 Flux.@functor DeepONet

-# The actual layer that does stuff
-# x needs to be at least a 2-dim array,
-# since we need n inputs, evaluated at m locations
-function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
+#= The actual layer that does stuff
+x is the input function, evaluated at m locations (or m x b in case of batches)
+y is the array of sensors, i.e. the variables of the output function
+with shape (N x n) - N different variables with n evaluation points each =#
+function (a::DeepONet)(x::AbstractVecOrMat, y::AbstractVecOrMat)
     # Assign the parameters
     branch, trunk = a.branch_net, a.trunk_net

-    # Dot product needs a dim to contract
-    # However, inputs are normally given with batching done in the same dim
-    # so we need to adjust (i.e. transpose) one of the inputs,
-    # and that's easiest on the matrix-type input
+    #= Dot product needs a dim to contract
+    However, we perform the transformations by the NNs always in the first dim
+    so we need to adjust (i.e. transpose) one of the inputs,
+    which we do on the branch input here =#
     return branch(x)' * trunk(y)
 end

-# Handling batches:
-# We use basically the same function, but using NNlib's batched_mul instead of
-# regular matrix-matrix multiplication
-function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
-    # Assign the parameters
-    branch, trunk = a.branch_net, a.trunk_net
-
-    # Dot product needs a dim to contract
-    # However, inputs are normally given with batching done in the same dim
-    # so we need to adjust (i.e. transpose) one of the inputs,
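The batched method being removed here relied on NNlib's `batched_mul`, whose semantics are an independent matrix product per batch slice. A minimal pure-Python sketch of that behavior, with the transpose on the branch side as in the forward pass above (`batched_mul_T` is a hypothetical helper, not the library call):

```python
# Sketch of per-slice batched multiplication: for each batch slice l,
# out[l] = b[l]^T @ t[l], mirroring branch(x)' * trunk(y) slice by slice.
def batched_mul_T(b, t):
    out = []
    for bl, tl in zip(b, t):
        p, m = len(bl), len(bl[0])   # latent features, rows after transpose
        n = len(tl[0])               # query points
        out.append([[sum(bl[i][j] * tl[i][k] for i in range(p))
                     for k in range(n)] for j in range(m)])
    return out

b = [[[1.0], [2.0]]]             # one batch slice: p = 2, m = 1
t = [[[3.0, 0.0], [1.0, 1.0]]]   # one batch slice: p = 2, n = 2
print(batched_mul_T(b, t))       # [[[5.0, 2.0]]]
```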