|
2 | 2 | m1 = collect(1:3.0) |
3 | 3 | m2 = (collect(1:3.0), collect(4:6.0)) |
4 | 4 | m3 = (x = m1, y = sin, z = collect(4:6.0)) |
5 | | -m4 = (x = m1, y = m1, z = collect(4:6.0)) |
| 5 | +m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied |
6 | 6 | m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) |
7 | 7 | m6 = (a = m1, b = [4.0 + im], c = m1) |
8 | 8 | m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) |
|
72 | 72 | @test g8[3] == [[10.0]] |
73 | 73 |
|
74 | 74 | @testset "second derivative" begin |
75 | | - @test_broken gradient([1,2,3.0]) do v |
| 75 | + @test gradient([1,2,3.0]) do v |
76 | 76 | sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) |
77 | 77 | end[1] ≈ [8,16,24] |
| 78 | + # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: |
| 79 | + # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... |
| 80 | + # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing |
| 81 | + # With Zygote, instead: |
| 82 | + # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) |
| 83 | + |
| 84 | + @test gradient([1,2,3.0]) do v |
| 85 | + sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1]) |
| 86 | + end[1] == [378, 378, 378] |
78 | 87 |
|
79 | | - @test_skip gradient([1,2,3.0]) do v |
80 | | - sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1]) |
81 | | - end |
| 88 | + @test_broken gradient([1,2,3.0]) do v |
| 89 | + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) |
| 90 | + end[1] ≈ [8,16,24] |
| 91 | + # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) |
| 92 | + # Diffractor error in perform_optic_transform |
82 | 93 | end |
83 | 94 | end |
84 | 95 |
|
@@ -109,15 +120,17 @@ end |
109 | 120 | @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] |
110 | 121 |
|
111 | 122 | @testset "second derivative" begin |
112 | | - # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} |
113 | 123 | @test_broken gradient(collect(1:6.0)) do y |
114 | 124 | sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) |
115 | 125 | end[1] ≈ [8,16,24,0,0,0] |
116 | | - # This fixes it! |
| 126 | + # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} |
| 127 | + # with Zygote, which can be fixed by: |
117 | 128 | # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) |
118 | | - @test_skip gradient(collect(1:6.0)) do y |
| 129 | + |
| 130 | + @test_broken gradient(collect(1:6.0)) do y |
119 | 131 | sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) |
120 | | - end[1] |
| 132 | + end[1] ≈ [0,0,0,32,40,48] |
| 133 | + # Not fixed by this: |
121 | 134 | # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) |
122 | 135 | end |
123 | 136 | end |
|
0 commit comments