Skip to content

Commit b62e0a2

Browse files
committed
more... the dimensionmismatch bug is not here
1 parent 756b450 commit b62e0a2

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

src/destructure.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,24 @@ Base.length(re::Restructure) = re.length
5555

5656
# This flattens a model, and returns a web of offsets for later use:
5757
function _flatten(x)
58-
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
58+
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
5959
arrays = AbstractVector[]
6060
len = Ref(0)
6161
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
62-
push!(arrays, vec(y))
62+
push!(arrays, _vec(y))
6363
o = len[]
6464
len[] = o + length(y)
6565
o
6666
end
6767
reduce(vcat, arrays), off, len[]
6868
end
6969

70+
_vec(x::Number) = LinRange(x,x,1)
71+
_vec(x::AbstractArray) = vec(x)
72+
7073
function ChainRulesCore.rrule(::typeof(_flatten), x)
7174
flat, off, len = _flatten(x)
72-
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
75+
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
7376
(flat, off, len), _flatten_back
7477
end
7578

@@ -92,7 +95,7 @@ function _trainable_biwalk(f, x, aux)
9295
end
9396

9497
function _trainmap(f, ch, tr, aux)
95-
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
98+
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
9699
isnothing(t) ? c : f(t, a)
97100
end
98101
end
@@ -121,7 +124,7 @@ ChainRulesCore.@non_differentiable _zero(x)
121124
# This is the gradient of model reconstruction, accumulating duplicates:
122125
function _grad!(x, dx, off, flat::AbstractVector)
123126
x′, _ = functor(typeof(x), x)
124-
dx′, _ = functor(typeof(x), dx)
127+
dx′, _ = functor(typeof(x), base(dx))
125128
off′, _ = functor(typeof(x), off)
126129
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
127130
flat
@@ -134,7 +137,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx
134137
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity
135138

136139
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
137-
println("grad! fwd ", length(flat))
138140
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
139141
_grad!(x, dx, off, flat), _grad_back
140142
end

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
33
base(dx::Tangent) = backing(canonicalize(dx))
4+
base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure))
45
base(dx) = dx
56
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
67

test/destructure.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
m1 = collect(1:3.0)
33
m2 = (collect(1:3.0), collect(4:6.0))
44
m3 = (x = m1, y = sin, z = collect(4:6.0))
5-
m4 = (x = m1, y = m1, z = collect(4:6.0))
5+
m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
66
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
77
m6 = (a = m1, b = [4.0 + im], c = m1)
88
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
@@ -75,10 +75,18 @@ end
7575
@test_broken gradient([1,2,3.0]) do v
7676
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
7777
end[1] ≈ [8,16,24]
78+
# With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
79+
# off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
80+
# until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
7881

79-
@test_skip gradient([1,2,3.0]) do v
80-
sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
81-
end
82+
@test_broken gradient([1,2,3.0]) do v
83+
sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
84+
end[1] == [378, 378, 378]
85+
86+
@test_broken gradient([1,2,3.0]) do v
87+
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
88+
end[1] ≈ [8,16,24]
89+
# Diffractor error in perform_optic_transform
8290
end
8391
end
8492

@@ -109,15 +117,17 @@ end
109117
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
110118

111119
@testset "second derivative" begin
112-
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
113120
@test_broken gradient(collect(1:6.0)) do y
114121
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
115122
end[1] ≈ [8,16,24,0,0,0]
116-
# This fixes it!
123+
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
124+
# with Zygote, which can be fixed by:
117125
# Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
118-
@test_skip gradient(collect(1:6.0)) do y
126+
127+
@test_broken gradient(collect(1:6.0)) do y
119128
sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
120-
end[1]
129+
end[1] ≈ [0,0,0,32,40,48]
130+
# Not fixed by this:
121131
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
122132
end
123133
end

0 commit comments

Comments
 (0)