@@ -64,12 +64,6 @@ Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
 struct DeepONet
     branch_net::Flux.Chain
     trunk_net::Flux.Chain
-    # Constructor for the DeepONet
-    function DeepONet(
-        branch_net::Flux.Chain,
-        trunk_net::Flux.Chain)
-        new(branch_net, trunk_net)
-    end
 end
 
 # Declare the function that assigns Weights and biases to the layer
@@ -79,6 +73,10 @@ function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                   init_trunk = Flux.glorot_uniform,
                   bias_branch=true, bias_trunk=true)
 
+    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk
+    net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ
+    won't work."
+
     # To construct the subnets we use the helper function in subnets.jl
     # Initialize the branch net
     branch_net = construct_subnet(architecture_branch, act_branch;
@@ -103,7 +101,7 @@ function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
     # However, inputs are normally given with batching done in the same dim
     # so we need to adjust (i.e. transpose) one of the inputs,
     # and that's easiest on the matrix-type input
-    return branch(x) * trunk(y)'
+    return branch(x)' * trunk(y)
 end
 
 # Handling batches:
0 commit comments