Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 9ac12ad

Browse files
committed
fix tests
1 parent a4e887f commit 9ac12ad

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

test/Transform/fourier_transform.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
ft = FourierTransform((3, 4, 5))
77

8-
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
8+
@test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch)
99
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10-
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, 𝐱)), size(transform(ft, 𝐱)) ),
11+
size(𝐱))) == (30, 40, 50, ch, batch)
1112

12-
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)), size(𝐱))), 𝐱)
13+
g = Zygote.gradient(x -> sum(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, x)),
14+
(16, 40, 50, ch, batch) ), (30, 40, 50, ch, batch) )), 𝐱)
1315
@test size(g[1]) == (30, 40, 50, ch, batch)
14-
end
16+
end

test/operator_kernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ end
7171
end
7272

7373
@testset "2D OperatorConv" begin
74-
modes = (16, 16)
74+
modes = (10, 10)
7575
ch = 64 => 64
7676

7777
m = Chain(Dense(1, 64),
@@ -87,7 +87,7 @@ end
8787
end
8888

8989
@testset "permuted 2D OperatorConv" begin
90-
modes = (16, 16)
90+
modes = (10, 10)
9191
ch = 64 => 64
9292

9393
m = Chain(Conv((1, 1), 1 => 64),
@@ -104,7 +104,7 @@ end
104104
end
105105

106106
@testset "2D OperatorKernel" begin
107-
modes = (16, 16)
107+
modes = (10, 10)
108108
ch = 64 => 64
109109

110110
m = Chain(Dense(1, 64),
@@ -119,7 +119,7 @@ end
119119
end
120120

121121
@testset "permuted 2D OperatorKernel" begin
122-
modes = (16, 16)
122+
modes = (10, 10)
123123
ch = 64 => 64
124124

125125
m = Chain(Conv((1, 1), 1 => 64),

0 commit comments

Comments
 (0)