11export ChebyshevTransform
22
3- struct ChebyshevTransform{N, S}<: AbstractTransform
3+ struct ChebyshevTransform{N, S} <: AbstractTransform
44 modes:: NTuple{N, S} # N == ndims(x)
55end
66
@@ -11,7 +11,7 @@ function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
1111end
1212
1313function truncate_modes (t:: ChebyshevTransform , 𝐱̂:: AbstractArray )
14- return view (𝐱̂, map (d-> 1 : d, t. modes)... , :, :) # [t.modes..., in_chs, batch]
14+ return view (𝐱̂, map (d -> 1 : d, t. modes)... , :, :) # [t.modes..., in_chs, batch]
1515end
1616
1717function inverse (t:: ChebyshevTransform{N} , 𝐱̂:: AbstractArray ) where {N}
2121
2222function ChainRulesCore. rrule (:: typeof (FFTW. r2r), x:: AbstractArray , kind, dims)
2323 y = FFTW. r2r (x, kind, dims)
24- (M,) = size (x)[dims]
25- r2r_pullback (Δ) = (NoTangent (), ∇r2r (unthunk (Δ), kind, dims, M), NoTangent (), NoTangent ())
24+ r2r_pullback (Δ) = (NoTangent (), ∇r2r (unthunk (Δ), kind, dims), NoTangent (), NoTangent ())
2625 return y, r2r_pullback
2726end
2827
29- function ∇r2r (Δ:: AbstractArray , kind, dims, M)
30- # derivative of r2r turns out to be r2r + a rank 4 correction
28+ function ∇r2r (Δ:: AbstractArray{T} , kind, dims) where {T}
29+ # derivative of r2r turns out to be r2r
3130 Δx = FFTW. r2r (Δ, kind, dims)
32-
33- # a1 = fill!(similar(A, M), one(T))
31+
32+ # rank 4 correction: needs @bischtob to elaborate the reason using this.
33+ # (M,) = size(Δ)[dims]
34+ # a1 = fill!(similar(Δ, M), one(T))
3435 # CUDA.@allowscalar a1[1] = a1[end] = zero(T)
3536
36- # a2 = fill!(similar(A , M), one(T))
37+ # a2 = fill!(similar(Δ , M), one(T))
3738 # a2[1:2:end] .= -one(T)
3839 # CUDA.@allowscalar a2[1] = a2[end] = zero(T)
3940
40- # e1 = fill!(similar(A , M), zero(T))
41+ # e1 = fill!(similar(Δ , M), zero(T))
4142 # CUDA.@allowscalar e1[1] = one(T)
4243
43- # eN = fill!(similar(A , M), zero(T))
44+ # eN = fill!(similar(Δ , M), zero(T))
4445 # CUDA.@allowscalar eN[end] = one(T)
4546
46- # @tullio Δx[s, i, b] +=
47- # a1[i] * e1[k] * Δ[s, k, b] - a2[i] * eN[k] * Δ[s, k, b]
48- # @tullio Δx[s, i, b] +=
49- # eN[i] * a2[k] * Δ[s, k, b] - e1[i] * a1[k] * Δ[s, k, b]
47+ # Δx .+= @. a1' * sum(e1' .* Δ, dims=2) - a2' * sum(eN' .* Δ, dims=2)
48+ # Δx .+= @. eN' * sum(a2' .* Δ, dims=2) - e1' * sum(a1' .* Δ, dims=2)
5049 return Δx
5150end
0 commit comments