Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit fed68b8

Browse files
committed
fix formatting
1 parent cdea34e commit fed68b8

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

src/Transform/chebyshev_transform.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
1414
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
1515
end
1616

17-
function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N}
17+
function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N},
18+
M::NTuple{N, Int64}) where {T, N}
1819
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
1920
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
2021
end

src/Transform/fourier_transform.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ end
1616

1717
truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)
1818

19-
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N}
19+
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N},
20+
M::NTuple{N, Int64}) where {T, N}
2021
return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
2122
end

test/Transform/chebyshev_transform.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
@test ndims(t) == 3
99
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
1010
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
11-
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch)
11+
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) ==
12+
(3, 4, 5, ch, batch)
1213

1314
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱)
1415
@test size(g[1]) == (30, 40, 50, ch, batch)

test/Transform/fourier_transform.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@
77

88
@test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch)
99
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10-
@test size(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, 𝐱)), size(transform(ft, 𝐱)) ),
11-
size(𝐱))) == (30, 40, 50, ch, batch)
10+
@test size(inverse(ft,
11+
NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)),
12+
size(transform(ft, 𝐱))),
13+
size(𝐱))) == (30, 40, 50, ch, batch)
1214

13-
g = Zygote.gradient(x -> sum(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, x)),
14-
(16, 40, 50, ch, batch) ), (30, 40, 50, ch, batch) )), 𝐱)
15+
g = Zygote.gradient(x -> sum(inverse(ft,
16+
NeuralOperators.pad_modes(truncate_modes(ft,
17+
transform(ft,
18+
x)),
19+
(16, 40, 50, ch, batch)),
20+
(30, 40, 50, ch, batch))), 𝐱)
1521
@test size(g[1]) == (30, 40, 50, ch, batch)
16-
end
22+
end

0 commit comments

Comments
 (0)