Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit cc0fa11

Browse files
committed
complete SparseKernel{N}
1 parent aa85e56 commit cc0fa11

File tree

3 files changed

+83
-51
lines changed

src/wavelet.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,44 @@
1-
struct SparseKernel{T,S}
2-
k::Int
3-
conv_blk::S
4-
out_weight::T
1+
export
    SparseKernel,
    SparseKernel1D,
    SparseKernel2D,
    SparseKernel3D


"""
    SparseKernel{N,T,S}

Sparse kernel layer for `N`-dimensional inputs. Wraps a convolutional
embedding block and a dense output projection; use the `SparseKernel1D`,
`SparseKernel2D`, `SparseKernel3D` convenience constructors.
"""
struct SparseKernel{N,T,S}
    conv_blk::T     # convolutional block embedding input channels
    out_weight::S   # dense layer projecting embeddings back to input channels
end
12+
13+
"""
    SparseKernel(filter::NTuple{N}, ch::Pair; init=Flux.glorot_uniform)

Construct an `N`-dimensional `SparseKernel` from a conv `filter` size tuple
and an `input_dim => emb_dim` channel pair.

# Arguments
- `filter`: kernel size per spatial dimension, e.g. `(3, 3)` for 2D.
- `ch`: `input_dim => emb_dim` channel pair.
- `init`: weight initializer passed through to `Conv` and `Dense`.
"""
function SparseKernel(filter::NTuple{N,T}, ch::Pair; init=Flux.glorot_uniform) where {N,T}
    input_dim, emb_dim = ch
    # Conv embeds input channels into the embedding space; with the width-3
    # filters used by the convenience constructors, pad=1/stride=1 preserves
    # the spatial size.
    conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    # Dense projection back to the original channel count.
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
end
619

7-
"""
    SparseKernel1D(k, α, c=1; init=Flux.glorot_uniform)

Convenience constructor for a 1-dimensional `SparseKernel` with a width-3
filter and `c*k` input channels.
"""
function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k
    # NOTE(review): `α` is accepted but unused here — the embedding width is
    # hard-coded to 128, unlike the 2D/3D variants. Confirm this is intended.
    emb_dim = 128
    return SparseKernel((3, ), input_dim=>emb_dim; init=init)
end
1425

15-
"""
    SparseKernel2D(k, α, c=1; init=Flux.glorot_uniform)

Convenience constructor for a 2-dimensional `SparseKernel` with a 3×3 filter,
`c*k^2` input channels and `α*k^2` embedding channels.
"""
function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k^2
    emb_dim = α*k^2
    return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
end
2231

23-
"""
    SparseKernel3D(k, α, c=1; init=Flux.glorot_uniform)

Convenience constructor for a 3-dimensional `SparseKernel` with a 3×3×3
filter. Built directly rather than via the generic `SparseKernel`
constructor because its conv maps `emb_dim=>emb_dim` instead of
`input_dim=>emb_dim`.
"""
function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    # NOTE(review): channel counts use k^2 (not k^3); this matches the shape
    # expectations in test/wavelet.jl — confirm against the reference paper.
    input_dim = c*k^2
    emb_dim = α*k^2
    # Unlike 1D/2D, the conv maps emb_dim=>emb_dim, so a 3D kernel's input is
    # expected to already carry α*k^2 channels (see the 3D test fixture).
    conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
end
3039

40+
# Register SparseKernel's fields as trainable parameters with Flux.
Flux.@functor SparseKernel
41+
3142
function (l::SparseKernel)(X::AbstractArray)
3243
bch_sz, _, dims_r... = reverse(size(X))
3344
dims = reverse(dims_r)

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
using NeuralOperators
using Test
using Flux
using Zygote
using CUDA

# Fail fast on accidental scalar indexing of GPU arrays.
CUDA.allowscalar(false)

@testset "NeuralOperators.jl" begin
    include("fourier.jl")
    include("wavelet.jl")
    include("model.jl")
end
914

test/wavelet.jl

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,53 @@
1-
using NeuralOperators
2-
using CUDA
3-
using Zygote
4-
5-
CUDA.allowscalar(false)
6-
7-
T = Float32
8-
k = 3
9-
batch_size = 32
10-
11-
α = 4
12-
c = 1
13-
in_chs = 20
14-
15-
16-
l1 = NeuralOperators.SparseKernel1d(k, α, c)
17-
X = rand(T, in_chs, c*k, batch_size)
18-
Y = l1(X)
19-
gradient(x->sum(l1(x)), X)
20-
21-
22-
α = 4
23-
c = 3
24-
Nx = 5
25-
Ny = 7
26-
27-
l2 = NeuralOperators.SparseKernel2d(k, α, c)
28-
X = rand(T, Nx, Ny, c*k^2, batch_size)
29-
Y = l2(X)
30-
gradient(x->sum(l2(x)), X)
31-
32-
Nz = 13
33-
34-
l3 = NeuralOperators.SparseKernel3d(k, α, c)
35-
X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
36-
Y = l3(X)
37-
gradient(x->sum(l3(x)), X)
1+
@testset "SparseKernel" begin
    # Shared fixture parameters across all three dimensionalities.
    T = Float32
    k = 3
    batch_size = 32

    @testset "1D SparseKernel" begin
        α = 4
        c = 1
        in_chs = 20
        X = rand(T, in_chs, c*k, batch_size)

        layer = SparseKernel1D(k, α, c)
        Y = layer(X)
        @test layer isa SparseKernel{1}
        # 1D kernel maps back to the input channel count, so shape is preserved.
        @test size(Y) == size(X)

        grads = gradient(() -> sum(layer(X)), Flux.params(layer))
        # Conv weight/bias + Dense weight/bias.
        @test length(grads.grads) == 4
    end

    @testset "2D SparseKernel" begin
        α = 4
        c = 3
        Nx = 5
        Ny = 7
        X = rand(T, Nx, Ny, c*k^2, batch_size)

        layer = SparseKernel2D(k, α, c)
        Y = layer(X)
        @test layer isa SparseKernel{2}
        @test size(Y) == size(X)

        grads = gradient(() -> sum(layer(X)), Flux.params(layer))
        @test length(grads.grads) == 4
    end

    @testset "3D SparseKernel" begin
        α = 4
        c = 3
        Nx = 5
        Ny = 7
        Nz = 13
        # 3D kernel consumes α*k^2 channels and emits c*k^2 channels.
        X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)

        layer = SparseKernel3D(k, α, c)
        Y = layer(X)
        @test layer isa SparseKernel{3}
        @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)

        grads = gradient(() -> sum(layer(X)), Flux.params(layer))
        @test length(grads.grads) == 4
    end
end

0 commit comments

Comments
 (0)