Skip to content

Commit d95a147

Browse files
committed
more... the dimensionmismatch bug is not here
1 parent 756b450 commit d95a147

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

src/destructure.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
33
const NoT = NoTangent()
44

5+
base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure))
6+
base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version
7+
58
"""
69
destructure(model) -> vector, reconstructor
710
@@ -55,21 +58,24 @@ Base.length(re::Restructure) = re.length
5558

5659
# This flattens a model, and returns a web of offsets for later use:
5760
function _flatten(x)
58-
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
61+
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
5962
arrays = AbstractVector[]
6063
len = Ref(0)
6164
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
62-
push!(arrays, vec(y))
65+
push!(arrays, _vec(y))
6366
o = len[]
6467
len[] = o + length(y)
6568
o
6669
end
6770
reduce(vcat, arrays), off, len[]
6871
end
6972

73+
_vec(x::Number) = LinRange(x,x,1)
74+
_vec(x::AbstractArray) = vec(x)
75+
7076
function ChainRulesCore.rrule(::typeof(_flatten), x)
7177
flat, off, len = _flatten(x)
72-
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
78+
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
7379
(flat, off, len), _flatten_back
7480
end
7581

@@ -92,7 +98,7 @@ function _trainable_biwalk(f, x, aux)
9298
end
9399

94100
function _trainmap(f, ch, tr, aux)
95-
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
101+
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
96102
isnothing(t) ? c : f(t, a)
97103
end
98104
end
@@ -121,7 +127,7 @@ ChainRulesCore.@non_differentiable _zero(x)
121127
# This is the gradient of model reconstruction, accumulating duplicates:
122128
function _grad!(x, dx, off, flat::AbstractVector)
123129
x′, _ = functor(typeof(x), x)
124-
dx′, _ = functor(typeof(x), dx)
130+
dx′, _ = functor(typeof(x), base(dx))
125131
off′, _ = functor(typeof(x), off)
126132
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
127133
flat
@@ -134,7 +140,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx
134140
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity
135141

136142
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
137-
println("grad! fwd ", length(flat))
138143
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
139144
_grad!(x, dx, off, flat), _grad_back
140145
end

test/destructure.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
m1 = collect(1:3.0)
33
m2 = (collect(1:3.0), collect(4:6.0))
44
m3 = (x = m1, y = sin, z = collect(4:6.0))
5-
m4 = (x = m1, y = m1, z = collect(4:6.0))
5+
m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
66
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
77
m6 = (a = m1, b = [4.0 + im], c = m1)
88
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
@@ -72,13 +72,24 @@ end
7272
@test g8[3] == [[10.0]]
7373

7474
@testset "second derivative" begin
75-
@test_broken gradient([1,2,3.0]) do v
75+
@test gradient([1,2,3.0]) do v
7676
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
7777
end[1] ≈ [8,16,24]
78+
# With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
79+
# off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
80+
# until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
81+
# With Zygote, instead:
82+
# dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),)
83+
84+
@test gradient([1,2,3.0]) do v
85+
sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
86+
end[1] == [378, 378, 378]
7887

79-
@test_skip gradient([1,2,3.0]) do v
80-
sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
81-
end
88+
@test_broken gradient([1,2,3.0]) do v
89+
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
90+
end[1] ≈ [8,16,24]
91+
# Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
92+
# Diffractor error in perform_optic_transform
8293
end
8394
end
8495

@@ -109,15 +120,17 @@ end
109120
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
110121

111122
@testset "second derivative" begin
112-
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
113123
@test_broken gradient(collect(1:6.0)) do y
114124
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
115125
end[1] ≈ [8,16,24,0,0,0]
116-
# This fixes it!
126+
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
127+
# with Zygote, which can be fixed by:
117128
# Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
118-
@test_skip gradient(collect(1:6.0)) do y
129+
130+
@test_broken gradient(collect(1:6.0)) do y
119131
sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
120-
end[1]
132+
end[1] ≈ [0,0,0,32,40,48]
133+
# Not fixed by this:
121134
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
122135
end
123136
end

0 commit comments

Comments
 (0)