File tree Expand file tree Collapse file tree 3 files changed +18
-2
lines changed Expand file tree Collapse file tree 3 files changed +18
-2
lines changed Original file line number Diff line number Diff line change 11name = " ChainRulesTestUtils"
22uuid = " cdddcdb0-9152-4a09-a978-84456f9df70a"
3- version = " 1.2.1 "
3+ version = " 1.2.2 "
44
55[deps ]
66ChainRulesCore = " d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Original file line number Diff line number Diff line change @@ -242,6 +242,7 @@ function test_rrule(
242242
243243 if check_thunked_output_tangent
244244 test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk:" )
245+ check_inferred && _test_inferred (pullback, @thunk (ȳ))
245246 end
246247 end # top-level testset
247248end
Original file line number Diff line number Diff line change 5757abstract type MySpecialTrait end
5858struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
5959
60-
6160@testset " testers.jl" begin
6261 @testset " test_scalar" begin
6362 @testset " Ensure correct rules succeed" begin
@@ -711,4 +710,20 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
711710
712711 ChainRulesTestUtils. TEST_INFERRED[] = true
713712 end
713+
714+ @testset " inference of thunked cotangents" begin
715+ my_id (x) = x
716+ function ChainRulesCore. rrule (:: typeof (my_id), x)
717+ my_id_pb (ȳ) = (NoTangent (), ȳ)
718+ function my_id_pb (ȳ:: AbstractThunk )
719+ precision = rand () > 0.5 ? Float64 : Float32
720+ return (NoTangent (), precision (unthunk (ȳ)))
721+ end
722+ return x, my_id_pb
723+ end
724+
725+ @test errors (() -> test_rrule (my_id, 2.0 ))
726+ test_rrule (my_id, 2.0 ; check_inferred= false )
727+ test_rrule (my_id, 2.0 ; check_thunked_output_tangent= false )
728+ end
714729end
You can’t perform that action at this time.
0 commit comments