
Commit 61d32af: Merge pull request #77 from SciML/refactor

Enhance docstring

2 parents 78446c4 + 9917824

5 files changed, +216 -34 lines changed

docs/src/apis.md

Lines changed: 33 additions & 4 deletions
@@ -1,5 +1,11 @@
 # APIs
 
+## Transforms
+
+```@docs
+AbstractTransform
+```
+
 ## Layers
 
 ### Operator convolutional layer
@@ -11,8 +17,9 @@ v'(x) = \mathcal{F}^{-1} \{ F'(s) \}
 ```
 
 where ``v(x)`` and ``v'(x)`` denote the input and output functions,
-``\mathcal{F} \{ \cdot \}``, ``\mathcal{F}^{-1} \{ \cdot \}`` are Fourier transform, inverse Fourier transform, respectively.
-Function ``g`` is a linear transform for lowering Fouier modes.
+``\mathcal{F} \{ \cdot \}``, ``\mathcal{F}^{-1} \{ \cdot \}`` are the transform and
+inverse transform, respectively.
+Function ``g`` is a linear transform for lowering the spectrum modes.
 
 ```@docs
 OperatorConv
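
Note: to make the recipe above concrete, here is a minimal stand-alone sketch of one spectral convolution in 1-D, assuming a Fourier backend. The names `spectral_conv`, `W`, and `modes` are illustrative, not the package API:

```julia
using FFTW

# Transform, keep the lowest `modes` modes, apply the linear map g
# (here a dense complex matrix W), then transform back.
function spectral_conv(v::AbstractVector{<:Real}, W::AbstractMatrix, modes::Int)
    F = rfft(v)                                     # 𝓕{v}
    F′ = W * F[1:modes]                             # g acting on the kept modes
    F_full = vcat(F′, zeros(ComplexF64, length(F) - modes))
    return irfft(F_full, length(v))                 # 𝓕⁻¹{F′}
end

v = rand(64)
W = randn(ComplexF64, 16, 16)
v′ = spectral_conv(v, W, 16)   # same length as v
```
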
@@ -28,7 +35,8 @@ Reference: [FNO2021](@cite)
 v_{t+1}(x) = \sigma(W v_t(x) + \mathcal{K} \{ v_t(x) \} )
 ```
 
-where ``v_t(x)`` is the input function for ``t``-th layer and ``\mathcal{K} \{ \cdot \}`` denotes spectral convolutional layer.
+where ``v_t(x)`` is the input function for the ``t``-th layer and
+``\mathcal{K} \{ \cdot \}`` denotes the spectral convolutional layer.
 Activation function ``\sigma`` can be an arbitrary non-linear function.
 
 ```@docs
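
Note: the update rule reads directly as one line of Julia. Here `kappa` stands in for the spectral convolution ``\mathcal{K}`` (e.g. an `OperatorConv`); all names are illustrative:

```julia
# v_{t+1} = σ.(W * v_t .+ 𝓚{v_t}); broadcasting σ applies it pointwise.
kernel_update(v, W, kappa, σ) = σ.(W * v .+ kappa(v))

# e.g. with the spectral_conv sketch above:
# v_next = kernel_update(v, randn(64, 64), u -> spectral_conv(u, W, 16), tanh)
```
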
@@ -45,7 +53,9 @@ Reference: [FNO2021](@cite)
 v_{t+1}(x_i) = \sigma(W v_t(x_i) + \frac{1}{|\mathcal{N}(x_i)|} \sum_{x_j \in \mathcal{N}(x_i)} \kappa \{ v_t(x_i), v_t(x_j) \} )
 ```
 
-where ``v_t(x_i)`` is the input function for ``t``-th layer, ``x_i`` is the node feature for ``i``-th node and ``\mathcal{N}(x_i)`` represents the neighbors for ``x_i``.
+where ``v_t(x_i)`` is the input function for the ``t``-th layer,
+``x_i`` is the node feature for the ``i``-th node, and
+``\mathcal{N}(x_i)`` denotes the neighbors of ``x_i``.
 Activation function ``\sigma`` can be an arbitrary non-linear function.
 
 ```@docs
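
Note: as an illustrative sketch (not the package implementation), the message-passing update averages ``\kappa`` over the neighborhood and adds the local linear term:

```julia
# v is a vector of node feature vectors; neighbors holds the indices of 𝒩(x_i).
function graph_kernel_update(v::Vector, W::AbstractMatrix, κ, neighbors::Vector{Int}, i::Int, σ)
    msg = sum(κ(v[i], v[j]) for j in neighbors) / length(neighbors)
    return σ.(W * v[i] .+ msg)
end
```
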
@@ -75,3 +85,22 @@ MarkovNeuralOperator
 ```
 
 Reference: [MNO2021](@cite)
+
+---
+
+### DeepONet
+
+```@docs
+DeepONet
+NeuralOperators.construct_subnet
+```
+
+---
+
+### NOMAD
+
+Nonlinear manifold decoders for operator learning.
+
+```@docs
+NOMAD
+```

src/NOMAD.jl

Lines changed: 7 additions & 6 deletions
@@ -4,12 +4,12 @@ struct NOMAD{T1, T2}
 end
 
 """
-`NOMAD(architecture_approximator::Tuple, architecture_decoder::Tuple,
-act_approximator = identity, act_decoder=true;
-init_approximator = Flux.glorot_uniform,
-init_decoder = Flux.glorot_uniform,
-bias_approximator=true, bias_decoder=true)`
-`NOMAD(approximator_net::Flux.Chain, decoder_net::Flux.Chain)`
+    NOMAD(architecture_approximator::Tuple, architecture_decoder::Tuple,
+          act_approximator = identity, act_decoder = true;
+          init_approximator = Flux.glorot_uniform,
+          init_decoder = Flux.glorot_uniform,
+          bias_approximator = true, bias_decoder = true)
+    NOMAD(approximator_net::Flux.Chain, decoder_net::Flux.Chain)
 
 Create a Nonlinear Manifold Decoder for Operator Learning (NOMAD) as proposed by Seidman et al. in
 arXiv:2206.03551
@@ -47,6 +47,7 @@ julia> model = NOMAD(approximator, decoder)
 NOMAD with
 Approximator net: (Chain(Dense(2 => 128), Dense(128 => 64)))
 Decoder net: (Chain(Dense(72 => 24), Dense(24 => 12)))
+```
 """
 function NOMAD(architecture_approximator::Tuple, architecture_decoder::Tuple,
                act_approximator = identity, act_decoder = true;
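
Note: a usage sketch built from the docstring's own example. The decoder's 72 inputs are presumably the 64 approximator outputs concatenated with the evaluation locations, but that split is an assumption here:

```julia
using Flux, NeuralOperators

approximator = Chain(Dense(2 => 128), Dense(128 => 64))
decoder = Chain(Dense(72 => 24), Dense(24 => 12))
model = NOMAD(approximator, decoder)   # prints the summary shown above
```
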

src/Transform/fourier_transform.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray)
     return view(𝐱_fft, map(d -> 1:d, ft.modes)..., :, :) # [ft.modes..., in_chs, batch]
 end
 
-const truncate_modes = low_pass
+truncate_modes(args...) = low_pass(args...)
 
 function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray)
     return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
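
Note: the change from a `const` alias to a forwarding method is behavioral, not cosmetic. An alias shares the underlying function object, so the two names can never diverge; a forwarding method keeps them independent (and lets `truncate_modes` carry its own docstring). A generic sketch, not package code:

```julia
f(x) = x + 1

const g = f               # `g` *is* `f`: one function object, two names
g(x::String) = x * "!"    # this adds a method to `f` as well

h(args...) = f(args...)   # `h` merely forwards to `f`
h(x::String) = x * "?"    # new methods on `h` leave `f` untouched
```
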

src/model.jl

Lines changed: 145 additions & 12 deletions
@@ -4,13 +4,81 @@ export
 
 """
     FourierNeuralOperator(;
-        ch=(2, 64, 64, 64, 64, 64, 128, 1),
-        modes=(16, ),
-        σ=gelu
-    )
+                          ch = (2, 64, 64, 64, 64, 64, 128, 1),
+                          modes = (16, ),
+                          σ = gelu)
 
-Fourier neural operator learns a neural operator with Dirichlet kernel to form a Fourier transformation.
-It performs Fourier transformation across infinite-dimensional function spaces and learns better than neural operator.
+Fourier neural operator is an operator learning model that uses a Fourier kernel to
+perform spectral convolutions.
+It is a promising approach for surrogate modeling, and can be regarded as a physics
+operator.
+
+The model is composed of
+a `Dense` layer to lift a (d + 1)-dimensional vector field to an n-dimensional vector field,
+an integral kernel operator consisting of four Fourier kernels,
+and two `Dense` layers to project the data back to the scalar field of interest.
+
+The role of each channel size is described as follows:
+
+```
+[1] input channel number
+    ↓ Dense
+[2] lifted channel number
+    ↓ OperatorKernel
+[3] mapped channel number
+    ↓ OperatorKernel
+[4] mapped channel number
+    ↓ OperatorKernel
+[5] mapped channel number
+    ↓ OperatorKernel
+[6] mapped channel number
+    ↓ Dense
+[7] projected channel number
+    ↓ Dense
+[8] projected channel number
+```
+
+## Keyword Arguments
+
+* `ch`: A `Tuple` or `Vector` of the 8 channel sizes.
+* `modes`: The modes to be preserved. A tuple of length `d`,
+    where `d` is the dimension of data.
+* `σ`: Activation function for all layers in the model.
+
+## Example
+
+```julia
+julia> using NNlib
+
+julia> FourierNeuralOperator(;
+                             ch = (2, 64, 64, 64, 64, 64, 128, 1),
+                             modes = (16,),
+                             σ = gelu)
+Chain(
+  Dense(2 => 64),                       # 192 parameters
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (16,), FourierTransform, permuted=false),  # 65_536 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (16,), FourierTransform, permuted=false),  # 65_536 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (16,), FourierTransform, permuted=false),  # 65_536 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (16,), FourierTransform, permuted=false),  # 65_536 parameters
+    identity,
+  ),
+  Dense(64 => 128, gelu),               # 8_320 parameters
+  Dense(128 => 1),                      # 129 parameters
+)                   # Total: 18 arrays, 287_425 parameters, 2.098 MiB.
+```
 """
 function FourierNeuralOperator(;
     ch = (2, 64, 64, 64, 64, 64, 128, 1),
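
Note: a hedged usage sketch of the constructor documented above. The channel-first data layout `(ch, x, batch)` is an assumption inferred from the leading `Dense(2 => 64)` lift; check the `OperatorConv` docs for the exact convention:

```julia
using Flux, NeuralOperators

model = FourierNeuralOperator(ch = (2, 64, 64, 64, 64, 64, 128, 1),
                              modes = (16,), σ = gelu)

𝐱 = rand(Float32, 2, 1024, 8)   # 2 input channels, 1024 grid points, batch of 8 (assumed layout)
𝐲 = model(𝐱)                    # expected: 1 output channel per grid point
```
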
@@ -29,14 +97,79 @@ end
 
 """
     MarkovNeuralOperator(;
-        ch=(1, 64, 64, 64, 64, 64, 1),
-        modes=(24, 24),
-        σ=gelu
-    )
+                         ch = (1, 64, 64, 64, 64, 64, 1),
+                         modes = (24, 24),
+                         σ = gelu)
 
 Markov neural operator learns a neural operator with Fourier operators.
-With only one time step information of learning, it can predict the following few steps with low loss
-by linking the operators into a Markov chain.
+Trained on the information of only one time step, it can predict the following few steps
+with low loss by linking the operators into a Markov chain.
+
+The model is composed of
+a `Dense` layer to lift a d-dimensional vector field to an n-dimensional vector field,
+an integral kernel operator consisting of four Fourier kernels,
+and a `Dense` layer to project the data back to the scalar field of interest.
+
+The role of each channel size is described as follows:
+
+```
+[1] input channel number
+    ↓ Dense
+[2] lifted channel number
+    ↓ OperatorKernel
+[3] mapped channel number
+    ↓ OperatorKernel
+[4] mapped channel number
+    ↓ OperatorKernel
+[5] mapped channel number
+    ↓ OperatorKernel
+[6] mapped channel number
+    ↓ Dense
+[7] projected channel number
+```
+
+## Keyword Arguments
+
+* `ch`: A `Tuple` or `Vector` of the 7 channel sizes.
+* `modes`: The modes to be preserved. A tuple of length `d`,
+    where `d` is the dimension of data.
+* `σ`: Activation function for all layers in the model.
+
+## Example
+
+```julia
+julia> using NNlib
+
+julia> MarkovNeuralOperator(;
+                            ch = (1, 64, 64, 64, 64, 64, 1),
+                            modes = (24, 24),
+                            σ = gelu)
+Chain(
+  Dense(1 => 64),                       # 128 parameters
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (24, 24), FourierTransform, permuted=false),  # 2_359_296 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (24, 24), FourierTransform, permuted=false),  # 2_359_296 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (24, 24), FourierTransform, permuted=false),  # 2_359_296 parameters
+    NNlib.gelu,
+  ),
+  OperatorKernel(
+    Dense(64 => 64),                    # 4_160 parameters
+    OperatorConv(64 => 64, (24, 24), FourierTransform, permuted=false),  # 2_359_296 parameters
+    NNlib.gelu,
+  ),
+  Dense(64 => 1),                       # 65 parameters
+)                   # Total: 16 arrays, 9_454_017 parameters, 72.066 MiB.
+```
 """
 function MarkovNeuralOperator(;
     ch = (1, 64, 64, 64, 64, 64, 1),
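
Note: a sketch of the "Markov chain" idea: because the operator maps one time step to the next, a trajectory is produced by applying the model recursively. Shapes are assumptions:

```julia
using Flux, NeuralOperators

model = MarkovNeuralOperator(ch = (1, 64, 64, 64, 64, 64, 1),
                             modes = (24, 24), σ = gelu)

v = rand(Float32, 1, 64, 64, 8)   # (ch, x, y, batch), assumed layout
for _ in 1:10
    v = model(v)   # each application advances the state by one time step
end
```
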

src/operator_kernel.jl

Lines changed: 30 additions & 11 deletions
@@ -19,18 +19,23 @@ function OperatorConv{P}(weight::T,
 end
 
 """
-    OperatorConv(
-        ch, modes, transform;
-        init=c_glorot_uniform, permuted=false, T=ComplexF32
-    )
+    OperatorConv(ch, modes, transform;
+                 init=c_glorot_uniform, permuted=false, T=ComplexF32)
 
 ## Arguments
 
-* `ch`: Input and output channel size, e.g. `64=>64`.
-* `modes`: The modes to be preserved.
+* `ch`: A `Pair` of input and output channel sizes `ch_in=>ch_out`, e.g. `64=>64`.
+* `modes`: The modes to be preserved. A tuple of length `d`,
+    where `d` is the dimension of data.
 * `transform`: The transform used to perform the transformation.
+
+## Keyword Arguments
+
+* `init`: Function used to initialize the parameters.
 * `permuted`: Whether the dims are permuted. If `permuted=true`, the layer accepts
-    data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
+    data in the order of `(ch, x_1, ..., x_d, batch)`,
+    otherwise the order is `(x_1, ..., x_d, ch, batch)`.
+* `T`: Data type of the parameters.
 
 ## Example
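
Note: for reference, a construction sketch matching the layer shown in the `FourierNeuralOperator` chain above, with `FourierTransform` as the transform and 16 preserved modes for 1-D data:

```julia
using NeuralOperators

conv = OperatorConv(64 => 64, (16,), FourierTransform)                     # (x, ch, batch) input
conv_p = OperatorConv(64 => 64, (16,), FourierTransform; permuted = true)  # (ch, x, batch) input
```
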
@@ -74,7 +79,11 @@ ispermuted(::OperatorConv{P}) where {P} = P
 
 function Base.show(io::IO, l::OperatorConv{P}) where {P}
     print(io,
-          "OperatorConv($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)")
+          "OperatorConv(" *
+          "$(l.in_channel) => $(l.out_channel), " *
+          "$(l.transform.modes), " *
+          "$(nameof(typeof(l.transform))), " *
+          "permuted=$P)")
 end
 
 function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
@@ -116,11 +125,17 @@ end
 
 ## Arguments
 
-* `ch`: Input and output channel size for spectral convolution, e.g. `64=>64`.
-* `modes`: The Fourier modes to be preserved for spectral convolution.
+* `ch`: A `Pair` of input and output channel sizes for the spectral convolution
+    `in_ch=>out_ch`, e.g. `64=>64`.
+* `modes`: The modes to be preserved for the spectral convolution. A tuple of length `d`,
+    where `d` is the dimension of data.
 * `σ`: Activation function.
+
+## Keyword Arguments
+
 * `permuted`: Whether the dims are permuted. If `permuted=true`, the layer accepts
-    data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
+    data in the order of `(ch, x_1, ..., x_d, batch)`,
+    otherwise the order is `(x_1, ..., x_d, ch, batch)`.
 
 ## Example
@@ -176,6 +191,10 @@ Graph kernel layer.
 * `κ`: A neural network layer for approximation, e.g. a `Dense` layer or an MLP.
 * `ch`: Channel size for the linear transform, e.g. `32`.
 * `σ`: Activation function.
+
+## Keyword Arguments
+
+* `init`: Function used to initialize the parameters.
 """
 struct GraphKernel{A, B, F} <: MessagePassing
     linear::A
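
Note: a hedged sketch of constructing the graph kernel layer. The `2 * ch` input width for `κ` reflects that κ acts on a pair of node features, but the exact argument conventions here are assumptions:

```julia
using Flux, NeuralOperators

κ = Dense(2 * 32 => 32, relu)   # kernel net on concatenated (v_i, v_j) features
layer = GraphKernel(κ, 32)      # ch = 32; σ assumed to default to identity
```
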
