11# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
22
33using AbstractFFTs
4- using AbstractFFTs: Plan
4+ using AbstractFFTs: Plan, ScaledPlan
55using ChainRulesTestUtils
6- using ChainRulesCore: NoTangent
6+ using FiniteDifferences
7+ import ChainRulesCore
78
89using LinearAlgebra
910using Random
293294 end
294295
295296 @testset " fft" begin
296- for x in (randn (2 ), randn (2 , 3 ), randn (3 , 4 , 5 ))
297- N = ndims (x)
298- complex_x = complex .(x)
297+ # Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256
298+ InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan}
299+ function FiniteDifferences. to_vec (x:: InnerPlan )
300+ function FFTPlan_from_vec (x_vec:: Vector )
301+ return x
302+ end
303+ return Bool[], FFTPlan_from_vec
304+ end
305+ ChainRulesTestUtils. test_approx (:: ChainRulesCore.AbstractZero , x:: InnerPlan , msg= " " ; kwargs... ) = true
306+ ChainRulesTestUtils. rand_tangent (:: AbstractRNG , x:: InnerPlan ) = ChainRulesCore. NoTangent ()
307+
308+ for x_shape in ((2 ,), (2 , 3 ), (3 , 4 , 5 ))
309+ N = length (x_shape)
310+ x = randn (x_shape)
311+ complex_x = x + randn (x_shape) * im
299312 for dims in unique ((1 , 1 : N, N))
300313 # fft, ifft, bfft
301314 for f in (fft, ifft, bfft)
@@ -305,17 +318,17 @@ end
305318 test_rrule (f, complex_x, dims)
306319 end
307320 for pf in (plan_fft, plan_ifft, plan_bfft)
308- test_frule (* , pf (x, dims) ⊢ NoTangent () , x)
309- test_rrule (* , pf (x, dims) ⊢ NoTangent () , x)
310- test_frule (* , pf (complex_x, dims) ⊢ NoTangent () , complex_x)
311- test_rrule (* , pf (complex_x, dims) ⊢ NoTangent () , complex_x)
321+ test_frule (* , pf (x, dims), x)
322+ test_rrule (* , pf (x, dims), x)
323+ test_frule (* , pf (complex_x, dims), complex_x)
324+ test_rrule (* , pf (complex_x, dims), complex_x)
312325 end
313326
314327 # rfft
315328 test_frule (rfft, x, dims)
316329 test_rrule (rfft, x, dims)
317- test_frule (* , plan_rfft (x, dims) ⊢ NoTangent () , x)
318- test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent () , x)
330+ test_frule (* , plan_rfft (x, dims), x)
331+ test_rrule (* , plan_rfft (x, dims), x)
319332
320333 # irfft, brfft
321334 for f in (irfft, brfft)
328341 end
329342 for pf in (plan_irfft, plan_brfft)
330343 for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
331- test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent () , complex_x)
332- test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent () , complex_x)
344+ test_frule (* , pf (complex_x, d, dims), complex_x)
345+ test_rrule (* , pf (complex_x, d, dims), complex_x)
333346 end
334347 end
335348 end
0 commit comments