File tree Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -117,7 +117,7 @@ test_scalar: relu at -0.5 | 11 11
117117
118118## Testing constructors and functors (callable objects)
119119
120- Testing constructor and functors works as you would expect. For struct ` Foo `
120+ Testing constructor and functors works as you would expect. For struct ` Foo ` ,
121121``` julia
122122struct Foo
123123 a:: Float64
@@ -127,7 +127,27 @@ Base.length(::Foo) = 1
127127Base. iterate (f:: Foo ) = iterate (f. a)
128128Base. iterate (f:: Foo , state) = iterate (f. a, state)
129129```
130- the ` f/rrule ` s can be tested by
130+
131+ after defining the constructor and functor ` f/rule ` s,
132+
133+ ``` julia
134+ function ChainRulesCore. rrule (:: Type{Foo} , val) # constructor rrule
135+ y = Foo (val)
136+ Foo_pb (ΔFoo) = (NoTangent (), unthunk (ΔFoo). a)
137+ return y, Foo_pb
138+ end
139+
140+ function ChainRulesCore. rrule (foo:: Foo , val) # functor rrule
141+ y = foo (val)
142+ function foo_pb (Δ)
143+ Δut = unthunk (Δ)
144+ return (Tangent {Foo} (;a= Δut), Δut)
145+ end
146+ return y, foo_pb
147+ end
148+ ```
149+
150+ both ` f/rrule ` s can be tested by
131151``` julia
132152test_rrule (Foo, rand ()) # constructor
133153
Original file line number Diff line number Diff line change @@ -125,7 +125,7 @@ function test_frule(
125125 end
126126
127127 res = call_on_copy (frule_f, config, tangents, primals... )
128- res === nothing && throw (MethodError (frule_f, typeof (primals)))
128+ res === nothing && throw (MethodError (frule_f, Tuple{Core . Typeof . (primals)... } ))
129129 @test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
130130 Ω_ad, dΩ_ad = res
131131 Ω = call_on_copy (primals... )
@@ -201,7 +201,7 @@ function test_rrule(
201201 _test_inferred (rrule_f, config, primals... ; fkwargs... )
202202 end
203203 res = rrule_f (config, primals... ; fkwargs... )
204- res === nothing && throw (MethodError (rrule_f, typeof (primals)))
204+ res === nothing && throw (MethodError (rrule_f, Tuple{Core . Typeof . (primals)... } ))
205205 y_ad, pullback = res
206206 y = call (primals... )
207207 test_approx (y_ad, y; isapprox_kwargs... ) # make sure primal is correct
You can’t perform that action at this time.
0 commit comments