@@ -53,10 +53,6 @@ function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
5353 return f (x), Δf. a + Δx
5454end
5555
56- # testing configs
57- abstract type MySpecialTrait end
58- struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
59-
6056# Type-stable derivative for test below
6157struct FVecOfTuplesPullback{T} end
6258function (f:: FVecOfTuplesPullback{T} )(Δ) where {T}
664660 test_rrule (f_notimplemented2, identity, randn ())
665661 end
666662
667- @testset " custom rrule_f" begin
668- only2x (x, y) = 2 x
669- custom (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2 Δ, ZeroTangent ())
670- wrong1 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (ZeroTangent (), 2 Δ, ZeroTangent ())
671- wrong2 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2.1 Δ, ZeroTangent ())
672- wrong3 (:: RuleConfig , :: typeof (only2x), x, y) = only2x (x, y), Δ -> (NoTangent (), 2 Δ)
673-
674- test_rrule (only2x, 2.0 , 3.0 ; rrule_f= custom, check_inferred= false )
675- @test errors (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong1, check_inferred= false ))
676- @test fails (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong2, check_inferred= false ))
677- @test fails (() -> test_rrule (only2x, 2.0 , 3.0 ; rrule_f= wrong3, check_inferred= false ))
678- end
679-
680- @testset " custom frule_f" begin
681- mytuple (x, y) = return 2 x, 1.0
682- T = Tuple{Float64, Float64}
683- custom (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2 Δx, ZeroTangent ())
684- wrong1 (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2.1 Δx, ZeroTangent ())
685- wrong2 (:: RuleConfig , (Δf, Δx, Δy), :: typeof (mytuple), x, y) = mytuple (x, y), Tangent {T} (2 Δx, 1.0 )
686-
687- test_frule (mytuple, 2.0 , 3.0 ; frule_f= custom, check_inferred= false )
688- @test fails (() -> test_frule (mytuple, 2.0 , 3.0 ; frule_f= wrong1, check_inferred= false ))
689- @test fails (() -> test_frule (mytuple, 2.0 , 3.0 ; frule_f= wrong2, check_inferred= false ))
690- end
691-
692- @testset " custom_config" begin
693- has_config (x) = 2 x
694- function ChainRulesCore. rrule (:: MySpecialConfig , :: typeof (has_config), x)
695- has_config_pullback (ȳ) = return (NoTangent (), 2 ȳ)
696- return has_config (x), has_config_pullback
697- end
698-
699- has_trait (x) = 2 x
700- function ChainRulesCore. rrule (:: RuleConfig{<:MySpecialTrait} , :: typeof (has_trait), x)
701- has_trait_pullback (ȳ) = return (NoTangent (), 2 ȳ)
702- return has_trait (x), has_trait_pullback
703- end
704-
705- # it works if the special config is provided
706- test_rrule (MySpecialConfig (), has_config, rand ())
707- test_rrule (MySpecialConfig (), has_trait, rand ())
708-
709- # but it doesn't work for the default config
710- errors (() -> test_rrule (has_config, rand ()), " no method matching rrule" )
711- errors (() -> test_rrule (has_trait, rand ()), " no method matching rrule" )
712- end
713-
714663 @testset " @maybe_inferred" begin
715664 f_noninferrable (x) = Ref {Real} (x)[]
716665
0 commit comments