@@ -49,4 +49,66 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
4949 errors (() -> test_rrule (has_config, rand ()), " no method matching rrule" )
5050 errors (() -> test_rrule (has_trait, rand ()), " no method matching rrule" )
5151 end
52+
53+ @testset " TestConfig direct" begin
54+ poly (x) = x^ 2 + 3.2 x
55+
56+ x = 2.1
57+ config = ChainRulesTestUtils. TestConfig ()
58+
59+ @testset " rrule" begin
60+ y, pb = rrule_via_ad (config, poly, x)
61+ @test y == poly (x)
62+ test_approx (pb (1.0 ), (NoTangent (), (2 * x + 3.2 ) * 1.0 ))
63+ # and automatically
64+ test_rrule (config, poly, rand (); rrule_f= rrule_via_ad, check_inferred= false )
65+ end
66+
67+ @testset " frule" begin
68+ ḟ, ẋ = (NoTangent (), rand ())
69+ Ω, ΔΩ = frule_via_ad (config, (ḟ, ẋ), poly, x)
70+ @test Ω == poly (x)
71+ test_approx (ΔΩ, (2 * x + 3.2 ) * ẋ)
72+ # and automatically
73+ test_frule (config, poly, x; frule_f= frule_via_ad, check_inferred= false )
74+ end
75+
76+ # more functions
77+ simo (x) = (x, 2 x, 3 x)
78+ miso (x, y, z) = x+ y
79+ test_rrule (config, simo, x; rrule_f= rrule_via_ad, check_inferred= false )
80+ test_rrule (config, miso, x, 2 x, " s" ; rrule_f= rrule_via_ad, check_inferred= false )
81+
82+ test_frule (config, simo, x; frule_f= frule_via_ad, check_inferred= false )
83+ test_frule (config, miso, x, x, " s" ; frule_f= frule_via_ad, check_inferred= false )
84+ end
85+
86+ @testset " TestConfig in a rule" begin
87+ inner (x, y) = x^ 2 + 2 * y + 3
88+ outer (f, x) = 2 * f (x, 3.2 )
89+
90+ function ChainRulesCore. rrule (config:: RuleConfig{>:HasReverseMode} , :: typeof (outer), f, x)
91+ fx, pb_f = rrule_via_ad (config, f, x, 3.2 )
92+ outer_pb (ȳ) = (NoTangent (), pb_f (2 * ȳ)[1 : 2 ]. .. )
93+ return outer (f, x), outer_pb
94+ end
95+
96+ function ChainRulesCore. frule (config:: RuleConfig{>:HasForwardsMode} , (ȯuter, ḟ, ẋ), :: typeof (outer), f, x)
97+ inner, inner_dot = frule_via_ad (config, (ḟ, ẋ, NoTangent ()), f, x, 3.2 )
98+ return 2 * inner, 2 * inner_dot
99+ end
100+
101+ config = ChainRulesTestUtils. TestConfig ()
102+ test_rrule (config, outer, inner, rand (); rrule_f= rrule_via_ad, check_inferred= false )
103+ test_frule (config, outer, inner, rand (); frule_f= frule_via_ad, check_inferred= false )
104+ end
105+
106+ @testset " Catch incorrect rules" begin
107+ myid (x) = x
108+ function ChainRulesCore. rrule (config:: RuleConfig{>:HasReverseMode} , :: typeof (myid), x)
109+ wrong_pb (dy) = (NoTangent (), 8 dy)
110+ return x, wrong_pb
111+ end
112+ @test fails (() -> test_rrule (myid, 3.0 ; rrule_f= rrule_via_ad, check_inferred= false ))
113+ end
52114end
0 commit comments