Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a4e887f

Browse files
committed
use rfft instead of fft
1 parent 60b7726 commit a4e887f

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

src/Transform/chebyshev_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
1414
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
1515
end
1616

17-
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
17+
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M) where {N}
1818
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
1919
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
2020
end

src/Transform/fourier_transform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
Base.ndims(::FourierTransform{N}) where {N} = N
88

99
function transform(ft::FourierTransform, 𝐱::AbstractArray)
10-
return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
10+
return rfft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
1111
end
1212

1313
function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray)
@@ -16,6 +16,6 @@ end
1616

1717
truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)
1818

19-
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray)
20-
return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
19+
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray, M)
20+
return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
2121
end

src/operator_kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
9292
𝐱_padded = pad_modes(𝐱_applied_pattern,
9393
(size(𝐱_transformed)[1:(end - 2)]...,
9494
size(𝐱_applied_pattern)[(end - 1):end]...)) # [size(x)..., out_chs, batch] <- [modes..., out_chs, batch]
95-
𝐱_inversed = inverse(m.transform, 𝐱_padded)
95+
𝐱_inversed = inverse(m.transform, 𝐱_padded, size(𝐱))
9696

9797
return 𝐱_inversed
9898
end

test/Transform/chebyshev_transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
@test ndims(t) == 3
99
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
1010
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
11-
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
11+
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch)
1212

13-
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
13+
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱)
1414
@test size(g[1]) == (30, 40, 50, ch, batch)
1515
end

test/Transform/fourier_transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
99
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10-
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch)
1111

12-
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱)
12+
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)), size(𝐱))), 𝐱)
1313
@test size(g[1]) == (30, 40, 50, ch, batch)
1414
end

0 commit comments

Comments
 (0)