@@ -40,7 +40,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
4040 @testset " split 3: forwards" begin
4141 # In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
4242 test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ))
43- test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ) .+ im)
43+ @test_skip test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ) .+ im) # not OK, assumed analyticity, fixed in PR710
4444 # Also, `sin∘cos` may use this path as CFG uses frule_via_ad
4545 # TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255
4646 end
@@ -177,7 +177,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
177177 end
178178
179179 @testset " bugs" begin
180- @test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
181- @test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
180+ @testset " unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661
181+ @test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
182+ @test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
183+ end
184+ @testset " unbroadcast with Matrix{Tangent}" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/708
185+ x = Base. Fix1 .(* , 1 : 3.0 )
186+ dx1 = [Tangent {Base.Fix1} (; x = i/ 2 ) for i in 1 : 3 , _ in 1 : 1 ]
187+ @test size (ChainRules. unbroadcast (x, dx1)) == size (x)
188+ dx2 = [Tangent {Base.Fix1} (; x = i/ j) for i in 1 : 3 , j in 1 : 4 ]
189+ @test size (ChainRules. unbroadcast (x, dx2)) == size (x) # was an error, convert(::ZeroTangent, ::Tangent)
190+ # sum(dx2; dims=2) isa Matrix{Union{ZeroTangent, Tangent{Base.Fix1...}}, ProjectTo copies this so that
191+ # unbroadcast(x, dx2) isa Vector{Tangent{...}}, that's probably not ideal.
192+
193+ @test sum (dx2; dims= 2 )[end ] == Tangent {Base.Fix1} (x = 6.25 ,)
194+ @test sum (dx1) isa Tangent{Base. Fix1} # no special code required
195+ end
182196 end
183197end
0 commit comments