@@ -18,7 +18,6 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
1818 isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
1919
2020 @testset " test_scalar: $f at $z " begin
21- _ensure_not_running_on_functor (f, " test_scalar" )
2221 # z = x + im * y
2322 # Ω = u(x, y) + im * v(x, y)
2423 Ω = f (z; fkwargs... )
@@ -30,8 +29,9 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
3029 test_frule (f, z ⊢ Δx; rule_test_kwargs... )
3130 if z isa Complex
3231 # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
33- _, real_tangent = frule ((ZeroTangent (), real (Δx)), f, z; fkwargs... )
34- _, embedded_tangent = frule ((ZeroTangent (), Δx), f, z; fkwargs... )
32+ ḟ = rand_tangent (f)
33+ _, real_tangent = frule ((ḟ, real (Δx)), f, z; fkwargs... )
34+ _, embedded_tangent = frule ((ḟ, Δx), f, z; fkwargs... )
3535 test_approx (real_tangent, embedded_tangent; isapprox_kwargs... )
3636 end
3737 end
7070 test_frule(f, args..; kwargs...)
7171
7272# Arguments
73- - `f`: Function for which the `frule` should be tested.
73+ - `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
7474- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
7575 - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7676 - `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
@@ -99,25 +99,29 @@ function test_frule(
9999 # To simplify some of the calls we make later lets group the kwargs for reuse
100100 isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
101101
102+ # and define a helper closure
103+ call_on_copy (f, xs... ) = deepcopy (f)(deepcopy (xs)... ; deepcopy (fkwargs)... )
104+
102105 @testset " test_frule: $f on $(_string_typeof (args)) " begin
103- _ensure_not_running_on_functor (f, " test_frule" )
104106
105- xẋs = auto_primal_and_tangent .(args)
106- xs = primal .(xẋs)
107- ẋs = tangent .(xẋs)
108- if check_inferred && _is_inferrable (f, deepcopy (xs)... ; deepcopy (fkwargs)... )
109- _test_inferred (frule, (NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
107+ primals_and_tangents = auto_primal_and_tangent .((f, args... ))
108+ primals = primal .(primals_and_tangents)
109+ tangents = tangent .(primals_and_tangents)
110+
111+ if check_inferred && _is_inferrable (deepcopy (primals)... ; deepcopy (fkwargs)... )
112+ _test_inferred (frule, deepcopy (tangents), deepcopy (primals)... ; deepcopy (fkwargs)... )
110113 end
111- res = frule ((NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
112- res === nothing && throw (MethodError (frule, typeof ((f, xs... ))))
114+
115+ res = frule (deepcopy (tangents), deepcopy (primals)... ; deepcopy (fkwargs)... )
116+ res === nothing && throw (MethodError (frule, typeof (primals)))
113117 @test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
114118 Ω_ad, dΩ_ad = res
115- Ω = f ( deepcopy (xs) ... ; deepcopy (fkwargs) ... )
119+ Ω = call_on_copy (primals ... )
116120 test_approx (Ω_ad, Ω; isapprox_kwargs... )
117121
118122 # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
119- ẋs_is_ignored = isa .(ẋs , Union{Nothing,NoTangent})
120- if any (ẋs .== nothing )
123+ is_ignored = isa .(tangents , Union{Nothing,NoTangent})
124+ if any (tangents .== nothing )
121125 Base. depwarn (
122126 " test_frule(f, k ⊢ nothing) is deprecated, use " *
123127 " test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
@@ -126,7 +130,7 @@ function test_frule(
126130 end
127131
128132 # Correctness testing via finite differencing.
129- dΩ_fd = _make_jvp_call (fdm, (xs ... ) -> f ( deepcopy (xs) ... ; deepcopy (fkwargs) ... ) , Ω, xs, ẋs, ẋs_is_ignored )
133+ dΩ_fd = _make_jvp_call (fdm, call_on_copy , Ω, primals, tangents, is_ignored )
130134 test_approx (dΩ_ad, dΩ_fd; isapprox_kwargs... )
131135
132136 acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
@@ -138,14 +142,14 @@ end
138142 test_rrule(f, args...; kwargs...)
139143
140144# Arguments
141- - `f`: Function to which rule should be applied.
142- - `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ `
145+ - `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
146+ - `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄ `
143147 - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144148 - `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145149 Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
146150
147151# Keyword Arguments
148- - `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
152+ - `output_tangent` the seed to propagate backward for testing (technically a cotangent).
149153 should be a differential for the output of `f`. Is set automatically if not provided.
150154 - `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
151155 - If `check_inferred=true`, then the inferrability of the `rrule` is checked
@@ -167,63 +171,66 @@ function test_rrule(
167171 # To simplify some of the calls we make later lets group the kwargs for reuse
168172 isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
169173
174+ # and define helper closure over fkwargs
175+ call (f, xs... ) = f (xs... ; fkwargs... )
176+
170177 @testset " test_rrule: $f on $(_string_typeof (args)) " begin
171- _ensure_not_running_on_functor (f, " test_rrule" )
172178
173179 # Check correctness of evaluation.
174- xx̄s = auto_primal_and_tangent .(args)
175- xs = primal .(xx̄s)
176- accumulated_x̄ = tangent .(xx̄s)
177- if check_inferred && _is_inferrable (f, xs... ; fkwargs... )
178- _test_inferred (rrule, f, xs... ; fkwargs... )
180+ primals_and_tangents = auto_primal_and_tangent .((f, args... ))
181+ primals = primal .(primals_and_tangents)
182+ accum_cotangents = tangent .(primals_and_tangents)
183+
184+ if check_inferred && _is_inferrable (primals... ; fkwargs... )
185+ _test_inferred (rrule, primals... ; fkwargs... )
179186 end
180- res = rrule (f, xs ... ; fkwargs... )
181- res === nothing && throw (MethodError (rrule, typeof ((f, xs ... ))))
187+ res = rrule (primals ... ; fkwargs... )
188+ res === nothing && throw (MethodError (rrule, typeof ((primals ... ))))
182189 y_ad, pullback = res
183- y = f (xs ... ; fkwargs ... )
190+ y = call (primals ... )
184191 test_approx (y_ad, y; isapprox_kwargs... ) # make sure primal is correct
185192
186193 ȳ = output_tangent isa Auto ? rand_tangent (y) : output_tangent
187194
188195 check_inferred && _test_inferred (pullback, ȳ)
189- ∂s = pullback (ȳ)
190- ∂s isa Tuple || error (" The pullback must return (∂self, ∂args...), not $∂s ." )
191- ∂self = ∂s[1 ]
192- x̄s_ad = ∂s[2 : end ]
193- @test ∂self === NoTangent () # No internal fields
194- msg = " The pullback should return 1 cotangent for each primal input."
195- @test_msg msg length (x̄s_ad) == length (args)
196+ ad_cotangents = pullback (ȳ)
197+ ad_cotangents isa Tuple || error (" The pullback must return (∂self, ∂args...), not $∂s ." )
198+ msg = " The pullback should return 1 cotangent for the primal and each primal input."
199+ @test_msg msg length (ad_cotangents) == 1 + length (args)
196200
197201 # Correctness testing via finite differencing.
198202 # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
199- x̄s_is_dne = isa .(accumulated_x̄ , Union{Nothing,NoTangent})
200- if any (accumulated_x̄ .== nothing )
203+ is_ignored = isa .(accum_cotangents , Union{Nothing, NoTangent})
204+ if any (accum_cotangents .== nothing )
201205 Base. depwarn (
202206 " test_rrule(f, k ⊢ nothing) is deprecated, use " *
203207 " test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
204208 :test_rrule ,
205209 )
206210 end
207211
208- x̄s_fd = _make_j′vp_call (fdm, (xs... ) -> f (xs... ; fkwargs... ), ȳ, xs, x̄s_is_dne)
209- for (accumulated_x̄, x̄_ad, x̄_fd) in zip (accumulated_x̄, x̄s_ad, x̄s_fd)
210- if accumulated_x̄ isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
211- @assert x̄_fd === nothing # this is how `_make_j′vp_call` works
212- x̄_ad isa ZeroTangent && error (
213- " The pullback in the rrule for $f function should use NoTangent()" *
212+ fd_cotangents = _make_j′vp_call (fdm, call, ȳ, primals, is_ignored)
213+
214+ for (accum_cotangent, ad_cotangent, fd_cotangent) in zip (
215+ accum_cotangents, ad_cotangents, fd_cotangents
216+ )
217+ if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
218+ @assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
219+ ad_cotangent isa ZeroTangent && error (
220+ " The pullback in the rrule should use NoTangent()" *
214221 " rather than ZeroTangent() for non-perturbable arguments." ,
215222 )
216- @test x̄_ad isa NoTangent # we said it wasn't differentiable.
223+ @test ad_cotangent isa NoTangent # we said it wasn't differentiable.
217224 else
218- x̄_ad isa AbstractThunk && check_inferred && _test_inferred (unthunk, x̄_ad )
225+ ad_cotangent isa AbstractThunk && check_inferred && _test_inferred (unthunk, ad_cotangent )
219226
220- # The main test of the actual deriviative being correct:
221- test_approx (x̄_ad, x̄_fd ; isapprox_kwargs... )
222- _test_add!!_behaviour (accumulated_x̄, x̄_ad ; isapprox_kwargs... )
227+ # The main test of the actual derivative being correct:
228+ test_approx (ad_cotangent, fd_cotangent ; isapprox_kwargs... )
229+ _test_add!!_behaviour (accum_cotangent, ad_cotangent ; isapprox_kwargs... )
223230 end
224231 end
225232
226- check_thunking_is_appropriate (x̄s_ad )
233+ check_thunking_is_appropriate (ad_cotangents )
227234 end # top-level testset
228235end
229236
@@ -236,16 +243,6 @@ function check_thunking_is_appropriate(x̄s)
236243 end
237244end
238245
239- function _ensure_not_running_on_functor (f, name)
240- # if x itself is a Type, then it is a constructor, thus not a functor.
241- # This also catchs UnionAll constructors which have a `:var` and `:body` fields
242- f isa Type && return nothing
243-
244- if fieldcount (typeof (f)) > 0
245- throw (ArgumentError (" $name cannot be used on closures/functors (such as $f )" ))
246- end
247- end
248-
249246"""
250247 @maybe_inferred [Type] f(...)
251248
0 commit comments