@@ -30,16 +30,20 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3030 halfdim = first (dims)
3131 d = size (x, halfdim)
3232 n = size (y, halfdim)
33- scale = reshape (
34- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
35- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
36- )
3733
3834 project_x = ChainRulesCore. ProjectTo (x)
3935 function rfft_pullback (ȳ)
4036 ybar = ChainRulesCore. unthunk (ȳ)
41- _scale = convert (typeof (ybar),scale)
42- x̄ = project_x (brfft (ybar ./ _scale, d, dims))
37+ ybar_scaled = map (ybar, CartesianIndices (ybar)) do ybar_j, j
38+ i = j[halfdim]
39+ ybar_scaled_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
40+ ybar_j
41+ else
42+ ybar_j / 2
43+ end
44+ return ybar_scaled_j
45+ end
46+ x̄ = project_x (brfft (ybar_scaled, d, dims))
4347 return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
4448 end
4549 return y, rfft_pullback
@@ -74,16 +78,20 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7478 n = size (x, halfdim)
7579 invN = AbstractFFTs. normalization (y, dims)
7680 twoinvN = 2 * invN
77- scale = reshape (
78- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? invN : twoinvN for i in 1 : n],
79- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
80- )
8181
8282 project_x = ChainRulesCore. ProjectTo (x)
8383 function irfft_pullback (ȳ)
8484 ybar = ChainRulesCore. unthunk (ȳ)
85- _scale = convert (typeof (ybar),scale)
86- x̄ = project_x (_scale .* rfft (real .(ybar), dims))
85+ x̄_scaled = rfft (real .(ybar), dims)
86+ x̄ = project_x (map (x̄_scaled, CartesianIndices (x̄_scaled)) do x̄_scaled_j, j
87+ i = j[halfdim]
88+ x̄_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
89+ invN * x̄_scaled_j
90+ else
91+ twoinvN * x̄_scaled_j
92+ end
93+ return x̄_j
94+ end )
8795 return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
8896 end
8997 return y, irfft_pullback
@@ -115,14 +123,19 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
115123 # compute scaling factors
116124 halfdim = first (dims)
117125 n = size (x, halfdim)
118- scale = reshape (
119- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
120- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
121- )
122126
123127 project_x = ChainRulesCore. ProjectTo (x)
124128 function brfft_pullback (ȳ)
125- x̄ = project_x (scale .* rfft (real .(ChainRulesCore. unthunk (ȳ)), dims))
129+ x̄_scaled = rfft (real .(ChainRulesCore. unthunk (ȳ)), dims)
130+ x̄ = project_x (map (x̄_scaled, CartesianIndices (x̄_scaled)) do x̄_scaled_j, j
131+ i = j[halfdim]
132+ x̄_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
133+ x̄_scaled_j
134+ else
135+ 2 * x̄_scaled_j
136+ end
137+ return x̄_j
138+ end )
126139 return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
127140 end
128141 return y, brfft_pullback
0 commit comments