Skip to content

Commit 42a7147

Browse files
committed
🧪 include tests for DeepONet
1 parent 4fa134c commit 42a7147

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

‎test/deeponet.jl‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using Test, Random, Flux
2+
3+
@testset "DeepONet" begin
4+
@testset "dimensions" begin
5+
# Test the proper construction
6+
# Branch net
7+
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].weight) == (72,64)
8+
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].bias) == (72,)
9+
# Trunk net
10+
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].weight) == (72,48)
11+
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].bias) == (72,)
12+
end
13+
14+
# Accept only Int as architecture parameters
15+
@test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
16+
@test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
17+
end

‎test/runtests.jl‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Random.seed!(0)
88
include("fourierlayer.jl")
99
end
1010

11+
@testset "DeepONet" begin
12+
include("deeponet.jl")
13+
end
14+
1115
@testset "Weights" begin
1216
include("complexweights.jl")
1317
end

0 commit comments

Comments
 (0)