11"""
2- DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
3- branch_activation = identity, trunk_activation = identity)
2+ DeepONet(branch, trunk, additional)
43
5- Constructs a DeepONet composed of Dense layers . Make sure the last node of `branch` and
6- `trunk` are same.
4+ Constructs a DeepONet from a `branch` and `trunk` architectures . Make sure that both the
5+ nets output should have the same first dimension .
76
8- ## Keyword arguments:
7+ ## Arguments
8+
9+ - `branch`: `Lux` network to be used as branch net.
10+ - `trunk`: `Lux` network to be used as trunk net.
11+
12+ ## Keyword Arguments
913
10- - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
11- - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
12- - `branch_activation`: activation function for branch net
13- - `trunk_activation`: activation function for trunk net
1414 - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
1515 for embeddings, defaults to `nothing`
1616
@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
2323## Example
2424
2525```jldoctest
26- julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
26+ julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
27+
28+ julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
29+
30+ julia> deeponet = DeepONet(branch_net, trunk_net);
2731
2832julia> ps, st = Lux.setup(Xoshiro(), deeponet);
2933
@@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
3539(10, 5)
3640```
3741"""
38- function DeepONet (;
39- branch= (64 , 32 , 32 , 16 ), trunk= (1 , 8 , 8 , 16 ), branch_activation= identity,
40- trunk_activation= identity, additional= nothing )
41-
42- # checks for last dimension size
43- @argcheck branch[end ]== trunk[end ] " Branch and Trunk net must share the same amount of \
44- nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
45- work."
46-
47- branch_net = Chain ([Dense (branch[i] => branch[i + 1 ], branch_activation)
48- for i in 1 : (length (branch) - 1 )]. .. )
49-
50- trunk_net = Chain ([Dense (trunk[i] => trunk[i + 1 ], trunk_activation)
51- for i in 1 : (length (trunk) - 1 )]. .. )
52-
53- return DeepONet (branch_net, trunk_net; additional)
42+ @concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
43+ branch
44+ trunk
45+ additional
5446end
5547
56- """
57- DeepONet(branch, trunk)
48+ DeepONet (branch, trunk) = DeepONet (branch, trunk, NoOpLayer ())
5849
59- Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
60- nets output should have the same first dimension.
61-
62- ## Arguments
50+ """
51+ DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
52+ branch_activation = identity, trunk_activation = identity)
6353
64- - `branch`: `Lux` network to be used as branch net.
65- - `trunk`: `Lux` network to be used as trunk net .
54+ Constructs a DeepONet composed of Dense layers. Make sure the last node of ` branch` and
55+ `trunk` are same .
6656
67- ## Keyword Arguments
57+ ## Keyword arguments:
6858
59+ - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
60+ - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
61+ - `branch_activation`: activation function for branch net
62+ - `trunk_activation`: activation function for trunk net
6963 - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
7064 for embeddings, defaults to `nothing`
7165
@@ -78,11 +72,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
7872## Example
7973
8074```jldoctest
81- julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
82-
83- julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
84-
85- julia> deeponet = DeepONet(branch_net, trunk_net);
75+ julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
8676
8777julia> ps, st = Lux.setup(Xoshiro(), deeponet);
8878
@@ -94,15 +84,32 @@ julia> size(first(deeponet((u, y), ps, st)))
9484(10, 5)
9585```
9686"""
97- function DeepONet (branch:: L1 , trunk:: L2 ; additional= nothing ) where {L1, L2}
98- return @compact (; branch, trunk, additional, dispatch= :DeepONet ) do (u, y)
99- t = trunk (y) # p x N x nb
100- b = branch (u) # p x u_size... x nb
87+ function DeepONet (;
88+ branch= (64 , 32 , 32 , 16 ), trunk= (1 , 8 , 8 , 16 ), branch_activation= identity,
89+ trunk_activation= identity, additional= NoOpLayer ())
90+
91+ # checks for last dimension size
92+ @argcheck branch[end ]== trunk[end ] " Branch and Trunk net must share the same amount of \
93+ nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
94+ work."
95+
96+ branch_net = Chain ([Dense (branch[i] => branch[i + 1 ], branch_activation)
97+ for i in 1 : (length (branch) - 1 )]. .. )
98+
99+ trunk_net = Chain ([Dense (trunk[i] => trunk[i + 1 ], trunk_activation)
100+ for i in 1 : (length (trunk) - 1 )]. .. )
101+
102+ return DeepONet (branch_net, trunk_net, additional)
103+ end
104+
105+ function (deeponet:: DeepONet )(x, ps, st:: NamedTuple )
106+ b, st_b = deeponet. branch (x[1 ], ps. branch, st. branch)
107+ t, st_t = deeponet. trunk (x[2 ], ps. trunk, st. trunk)
101108
102- @argcheck size (t , 1 )== size (b , 1 ) " Branch and Trunk net must share the same \
103- amount of nodes in the last layer. Otherwise \
104- Σᵢ bᵢⱼ tᵢₖ won't work."
109+ @argcheck size (b , 1 )== size (t , 1 ) " Branch and Trunk net must share the same amount of \
110+ nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
111+ work."
105112
106- @return __project (b, t, additional)
107- end
113+ out, st_a = __project (b, t, deeponet . additional, (; ps = ps . additional, st = st . additional) )
114+ return out, (branch = st_b, trunk = st_t, additional = st_a)
108115end
0 commit comments