1- # ffts
1+ module AbstractFFTsChainRulesCoreExt
2+
3+ using AbstractFFTs
4+ import ChainRulesCore
5+
26function ChainRulesCore. frule ((_, Δx, _), :: typeof (fft), x:: AbstractArray , dims)
37 y = fft (x, dims)
48 Δy = fft (Δx, dims)
@@ -33,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3337
3438 project_x = ChainRulesCore. ProjectTo (x)
3539 function rfft_pullback (ȳ)
36- x̄ = project_x (brfft (ChainRulesCore. unthunk (ȳ) ./ scale, d, dims))
40+ ybar = ChainRulesCore. unthunk (ȳ)
41+ _scale = convert (typeof (ybar),scale)
42+ x̄ = project_x (brfft (ybar ./ _scale, d, dims))
3743 return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
3844 end
3945 return y, rfft_pullback
@@ -46,7 +52,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
4652end
4753function ChainRulesCore. rrule (:: typeof (ifft), x:: AbstractArray , dims)
4854 y = ifft (x, dims)
49- invN = normalization (y, dims)
55+ invN = AbstractFFTs . normalization (y, dims)
5056 project_x = ChainRulesCore. ProjectTo (x)
5157 function ifft_pullback (ȳ)
5258 x̄ = project_x (invN .* fft (ChainRulesCore. unthunk (ȳ), dims))
@@ -66,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
6672 # compute scaling factors
6773 halfdim = first (dims)
6874 n = size (x, halfdim)
69- invN = normalization (y, dims)
75+ invN = AbstractFFTs . normalization (y, dims)
7076 twoinvN = 2 * invN
7177 scale = reshape (
7278 [i == 1 || (i == n && 2 * (i - 1 ) == d) ? invN : twoinvN for i in 1 : n],
@@ -75,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7581
7682 project_x = ChainRulesCore. ProjectTo (x)
7783 function irfft_pullback (ȳ)
78- x̄ = project_x (scale .* rfft (real .(ChainRulesCore. unthunk (ȳ)), dims))
84+ ybar = ChainRulesCore. unthunk (ȳ)
85+ _scale = convert (typeof (ybar),scale)
86+ x̄ = project_x (_scale .* rfft (real .(ybar), dims))
7987 return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
8088 end
8189 return y, irfft_pullback
@@ -152,12 +160,12 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
152160end
153161
154162# plans
155- function ChainRulesCore. frule ((_, _, Δx), :: typeof (* ), P:: Plan , x:: AbstractArray )
163+ function ChainRulesCore. frule ((_, _, Δx), :: typeof (* ), P:: AbstractFFTs. Plan , x:: AbstractArray )
156164 y = P * x
157165 Δy = P * Δx
158166 return y, Δy
159167end
160- function ChainRulesCore. rrule (:: typeof (* ), P:: Plan , x:: AbstractArray )
168+ function ChainRulesCore. rrule (:: typeof (* ), P:: AbstractFFTs. Plan , x:: AbstractArray )
161169 y = P * x
162170 project_x = ChainRulesCore. ProjectTo (x)
163171 Pt = P'
@@ -168,22 +176,25 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
168176 return y, mul_plan_pullback
169177end
170178
171- function ChainRulesCore. frule ((_, ΔP, Δx), :: typeof (* ), P:: ScaledPlan , x:: AbstractArray )
179+ function ChainRulesCore. frule ((_, ΔP, Δx), :: typeof (* ), P:: AbstractFFTs. ScaledPlan , x:: AbstractArray )
172180 y = P * x
173181 Δy = P * Δx .+ (ΔP. scale / P. scale) .* y
174182 return y, Δy
175183end
176- function ChainRulesCore. rrule (:: typeof (* ), P:: ScaledPlan , x:: AbstractArray )
184+ function ChainRulesCore. rrule (:: typeof (* ), P:: AbstractFFTs. ScaledPlan , x:: AbstractArray )
177185 y = P * x
178186 Pt = P'
179187 scale = P. scale
180188 project_x = ChainRulesCore. ProjectTo (x)
181189 project_scale = ChainRulesCore. ProjectTo (scale)
182190 function mul_scaledplan_pullback (ȳ)
183191 x̄ = ChainRulesCore. @thunk (project_x (Pt * ȳ))
184- scale_tangent = ChainRulesCore. @thunk (project_scale (dot (y, ȳ) / conj (scale)))
192+ scale_tangent = ChainRulesCore. @thunk (project_scale (AbstractFFTs . dot (y, ȳ) / conj (scale)))
185193 plan_tangent = ChainRulesCore. Tangent {typeof(P)} (;p= ChainRulesCore. NoTangent (), scale= scale_tangent)
186194 return ChainRulesCore. NoTangent (), plan_tangent, x̄
187195 end
188196 return y, mul_scaledplan_pullback
189197end
198+
199+ end # module
200+
0 commit comments