Skip to content

Commit 31fc2e1

Browse files
committed
🚫 get rid of DeepONet constructor
1 parent ad8c34a commit 31fc2e1

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

‎src/DeepONet.jl‎

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,6 @@ Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
6464
struct DeepONet
6565
branch_net::Flux.Chain
6666
trunk_net::Flux.Chain
67-
# Constructor for the DeepONet
68-
function DeepONet(
69-
branch_net::Flux.Chain,
70-
trunk_net::Flux.Chain)
71-
new(branch_net, trunk_net)
72-
end
7367
end
7468

7569
# Declare the function that assigns Weights and biases to the layer
@@ -79,6 +73,10 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
7973
init_trunk = Flux.glorot_uniform,
8074
bias_branch=true, bias_trunk=true)
8175

76+
@assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk
77+
net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ
78+
won't work."
79+
8280
# To construct the subnets we use the helper function in subnets.jl
8381
# Initialize the branch net
8482
branch_net = construct_subnet(architecture_branch, act_branch;
@@ -103,7 +101,7 @@ function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
103101
# However, inputs are normally given with batching done in the same dim
104102
# so we need to adjust (i.e. transpose) one of the inputs,
105103
# and that's easiest on the matrix-type input
106-
return branch(x) * trunk(y)'
104+
return branch(x)' * trunk(y)
107105
end
108106

109107
# Handling batches:

0 commit comments

Comments
 (0)