This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 1221dc9: Refactor project
1 parent: 61d32af

File tree: 8 files changed (+69, −62 lines)

src/DeepONet.jl renamed to src/DeepONet/DeepONet.jl
Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,7 @@
+export DeepONet
+
+include("subnets.jl")
+
 """
 `DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
           act_branch = identity, act_trunk = identity;
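For context, a minimal usage sketch of the constructor whose docstring begins above. The architecture tuples and activations are illustrative, and the keyword arguments (truncated in this diff hunk) are left at their defaults:

using NeuralOperators, Flux

# Branch net: 16 sensor inputs through hidden widths 22 and 16;
# trunk net: 1 evaluation coordinate through widths 16 and 16.
# The final widths of branch and trunk must match.
model = DeepONet((16, 22, 16), (1, 16, 16), σ, tanh)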
src/subnets.jl renamed to src/DeepONet/subnets.jl
File renamed without changes.

src/model.jl renamed to src/FNO/FNO.jl
File renamed without changes.

src/NOMAD.jl renamed to src/NOMAD/NOMAD.jl
Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+export NOMAD
+
 struct NOMAD{T1, T2}
     approximator_net::T1
     decoder_net::T2
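The struct's two fields give a default constructor that takes the networks directly. A sketch mirroring the configuration exercised in test/nomad.jl later in this commit (the layer sizes come from that test):

using Flux

approximator = Chain(Dense(16 => 22, σ), Dense(22 => 16, σ))
decoder = Chain(Dense(32 => 16, tanh; bias = false))
model = NOMAD(approximator, decoder)  # default struct constructor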

src/NeuralOperators.jl
Lines changed: 9 additions & 8 deletions

@@ -1,4 +1,5 @@
 module NeuralOperators
+
 using Flux
 using FFTW
 using Tullio
@@ -10,15 +11,15 @@ using ChainRulesCore
 using GeometricFlux
 using Statistics

-export DeepONet
-export NOMAD
-
+# kernels
 include("Transform/Transform.jl")
 include("operator_kernel.jl")
+include("graph_kernel.jl")
 include("loss.jl")
-include("model.jl")
-include("DeepONet.jl")
-include("subnets.jl")
-include("NOMAD.jl")

-end
+# models
+include("FNO/FNO.jl")
+include("DeepONet/DeepONet.jl")
+include("NOMAD/NOMAD.jl")
+
+end # module
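Taken together, the includes above imply the following source layout after the refactor (inferred from this commit's renames and new files):

src/
├── NeuralOperators.jl
├── Transform/Transform.jl
├── operator_kernel.jl
├── graph_kernel.jl        (new in this commit)
├── loss.jl
├── FNO/FNO.jl
├── DeepONet/
│   ├── DeepONet.jl        (includes subnets.jl)
│   └── subnets.jl
└── NOMAD/NOMAD.jl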

src/graph_kernel.jl
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+export GraphKernel
+
+"""
+    GraphKernel(κ, ch, σ=identity)
+
+Graph kernel layer.
+
+## Arguments
+
+* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
+* `ch`: Channel size for linear transform, e.g. `32`.
+* `σ`: Activation function.
+
+## Keyword Arguments
+
+* `init`: Initial function to initialize parameters.
+"""
+struct GraphKernel{A, B, F} <: MessagePassing
+    linear::A
+    κ::B
+    σ::F
+end
+
+function GraphKernel(κ, ch::Int, σ = identity; init = Flux.glorot_uniform)
+    W = init(ch, ch)
+    return GraphKernel(W, κ, σ)
+end
+
+Flux.@functor GraphKernel
+
+function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij)
+    return l.κ(vcat(x_i, x_j))
+end
+
+function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray)
+    return l.σ.(GeometricFlux._matmul(l.linear, x) + m)
+end
+
+function (l::GraphKernel)(el::NamedTuple, X::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, X)
+    _, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing)
+    return V
+end
+
+function Base.show(io::IO, l::GraphKernel)
+    channel, _ = size(l.linear)
+    print(io, "GraphKernel(", l.κ, ", channel=", channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ")")
+end
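A hedged usage sketch of the new layer. Since `message` concatenates `x_i` and `x_j`, `κ` must accept 2·ch inputs and return ch outputs so that `update` can add the aggregated messages to `W * x`. The `WithGraph` wiring is an assumption about how GeometricFlux supplies the precomputed `el` NamedTuple that the call method above expects:

using Flux, GeometricFlux, GraphSignals

ch = 32
κ = Dense(2ch => ch)               # message is vcat(x_i, x_j): 2ch in, ch out
layer = GraphKernel(κ, ch, relu)

fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])  # 3-node fully connected graph
X = rand(Float32, ch, 3)                    # node features, ch × num_nodes
Y = WithGraph(fg, layer)(X)                 # assumed wrapper; yields ch × 3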

src/operator_kernel.jl
Lines changed: 3 additions & 53 deletions

@@ -1,8 +1,7 @@
 export
     OperatorConv,
     SpectralConv,
-    OperatorKernel,
-    GraphKernel
+    OperatorKernel

 struct OperatorConv{P, T, S, TT}
     weight::T
@@ -66,8 +65,8 @@ function SpectralConv(ch::Pair{S, S},
                       init = c_glorot_uniform,
                       permuted = false,
                       T::DataType = ComplexF32) where {S <: Integer, N}
-    return OperatorConv(ch, modes, FourierTransform, init = init, permuted = permuted,
-                        T = T)
+    return OperatorConv(ch, modes, FourierTransform,
+                        init = init, permuted = permuted, T = T)
 end

 Flux.@functor OperatorConv{true}
@@ -181,55 +180,6 @@ function (m::OperatorKernel)(𝐱)
     return m.σ.(m.linear(𝐱) + m.conv(𝐱))
 end

-"""
-    GraphKernel(κ, ch, σ=identity)
-
-Graph kernel layer.
-
-## Arguments
-
-* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
-* `ch`: Channel size for linear transform, e.g. `32`.
-* `σ`: Activation function.
-
-## Keyword Arguments
-
-* `init`: Initial function to initialize parameters.
-"""
-struct GraphKernel{A, B, F} <: MessagePassing
-    linear::A
-    κ::B
-    σ::F
-end
-
-function GraphKernel(κ, ch::Int, σ = identity; init = Flux.glorot_uniform)
-    W = init(ch, ch)
-    return GraphKernel(W, κ, σ)
-end
-
-Flux.@functor GraphKernel
-
-function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij)
-    return l.κ(vcat(x_i, x_j))
-end
-
-function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray)
-    return l.σ.(GeometricFlux._matmul(l.linear, x) + m)
-end
-
-function (l::GraphKernel)(el::NamedTuple, X::AbstractArray)
-    GraphSignals.check_num_nodes(el.N, X)
-    _, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing)
-    return V
-end
-
-function Base.show(io::IO, l::GraphKernel)
-    channel, _ = size(l.linear)
-    print(io, "GraphKernel(", l.κ, ", channel=", channel)
-    l.σ == identity || print(io, ", ", l.σ)
-    print(io, ")")
-end
-
 #########
 # utils #
 #########
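For orientation, a brief sketch of the exports that remain in this file. The channel and mode values are illustrative, and the OperatorKernel call is an assumption based on the constructor conventions visible in the hunks above:

using NeuralOperators, Flux

# Fourier-space convolution: 2 input channels to 5 output channels,
# keeping the 16 lowest frequency modes along one spatial dimension.
sc = SpectralConv(2 => 5, (16,))

# OperatorKernel pairs the spectral path with a pointwise linear path,
# computing σ.(linear(x) + conv(x)) as in the call method above.
ok = OperatorKernel(2 => 5, (16,), FourierTransform, gelu)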

test/nomad.jl
Lines changed: 1 addition & 1 deletion

@@ -28,6 +28,6 @@
     "NOMAD with\nApproximator net: (Chain(Dense(16 => 22, σ), Dense(22 => 16, σ)))\nDecoder net: (Chain(Dense(32 => 16, tanh; bias=false)))\n"

     mgrad = Flux.Zygote.gradient(() -> sum(model(a, sensors)), Flux.params(model))
-    @info mgrad.grads
+    # @info mgrad.grads
     @test length(mgrad.grads) == 5
 end
