1+ """
2+ `DeepONet(in, out, grid, modes, σ=identity, init=glorot_uniform)`
3+ `DeepONet(Wf::AbstractArray, Wl::AbstractArray, [bias_f, bias_l, σ])`
4+
5+ Create a DeepONet architecture as proposed by Lu et al.
6+ arXiv:1910.03193
7+
8+ The model works as follows:
9+
10+ x --- branch --
11+ |
12+ -⊠--u-
13+ |
14+ y --- trunk ---
15+
16+ Where `x` represent the parameters of the PDE, discretely evaluated at its respective sensors,
17+ and `y` are the probing locations for the operator to be trained.
18+ `u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.
19+
20+ Both inputs `x` and `y` are multiplied together via dot product Σᵢ bᵢⱼ tᵢₖ.
21+
22+ ```julia
23+ model = DeepONet()
24+ ```
25+ """
26+ struct DeepONet
27+ branch_net:: Flux.Chain
28+ trunk_net:: Flux.Chain
29+ # Constructor for the DeepONet
30+ function DeepONet (
31+ branch_net:: Flux.Chain ,
32+ trunk_net:: Flux.Chain )
33+ new (branch_net, trunk_net)
34+ end
35+ end
36+
"""
    DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
             act_branch = identity, act_trunk = identity;
             init = Flux.glorot_uniform, bias_branch = true, bias_trunk = true)

Build a `DeepONet` from two tuples of layer widths, one for the branch net
and one for the trunk net, with per-subnet activation functions and bias
switches.
"""
function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                  act_branch = identity, act_trunk = identity;
                  init = Flux.glorot_uniform,
                  bias_branch = true, bias_trunk = true)
    # NOTE(review): `init` is accepted here but never forwarded to
    # `construct_subnet` — TODO confirm whether the helper should receive it.
    # Both subnets are assembled by the helper in subnets.jl.
    branch = construct_subnet(architecture_branch, act_branch; bias = bias_branch)
    trunk = construct_subnet(architecture_trunk, act_trunk; bias = bias_trunk)

    return DeepONet(branch, trunk)
end
51+
# Register DeepONet with Flux so its fields (the branch and trunk chains)
# are traversed for trainable parameters (Flux.params, optimisers, gpu/cpu).
Flux.@functor DeepONet
53+
"""
    (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)

Evaluate the operator network: feed the sensor values `x` through the branch
net and the probing locations `y` through the trunk net, then contract the
two outputs. `x` needs to be at least a 2-dim array, since we need n inputs
evaluated at m locations.
"""
function (a::DeepONet)(x::AbstractMatrix, y::AbstractVecOrMat)
    # The dot product needs a dimension to contract over, but inputs usually
    # batch along the same dimension; transposing the trunk output (the
    # matrix-type input) is the easiest way to line the dimensions up.
    branch_out = a.branch_net(x)
    trunk_out = a.trunk_net(y)
    return branch_out * trunk_out'
end
67+
"""
    (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)

Batched evaluation: identical contraction to the matrix method, but using
NNlib's batched multiplication `⊠` instead of regular matrix-matrix
multiplication so batched branch inputs are handled.
"""
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
    branch_out = a.branch_net(x)
    trunk_out = a.trunk_net(y)
    # As in the unbatched method, transpose the trunk output so the shared
    # latent dimension is the one being contracted.
    return branch_out ⊠ trunk_out'
end
81+
# Sensor locations are shared across a batch, so a batched `y` is rejected.
function (a::DeepONet)(x::AbstractArray, y::AbstractArray)
    throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))
end
85+
# Human-readable display listing both subnets.
function Base.show(io::IO, l::DeepONet)
    print(io, "DeepONet with\nbranch net: (", l.branch_net, ")\n")
    print(io, "Trunk net: (", l.trunk_net, ")\n")
end