From f7c1a7f1c543886f3a1a29b983c44849adaad076 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Feb 2022 23:51:43 -0500 Subject: [PATCH 01/13] destructure, take II --- docs/src/api.md | 6 +++ src/Optimisers.jl | 4 +- src/destructure.jl | 108 ++++++++++++++++++++++++++++++++++++++++++++ src/interface.jl | 2 +- test/destructure.jl | 84 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 5 +- 6 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 src/destructure.jl create mode 100644 test/destructure.jl diff --git a/docs/src/api.md b/docs/src/api.md index 5671140b..e43df995 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,6 +42,12 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`: Optimisers.trainable ``` +Such restrictions are also obeyed by this function for flattening a model: + +```@docs +Optimisers.destructure +``` + ## Rule Definition ```@docs diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 9f93e041..3ef067f1 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -4,8 +4,10 @@ using Functors: functor, fmap, isleaf using LinearAlgebra include("interface.jl") -include("rules.jl") +include("destructure.jl") +export destructure +include("rules.jl") export Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief, WeightDecay, ClipGrad, ClipNorm, OptimiserChain diff --git a/src/destructure.jl b/src/destructure.jl new file mode 100644 index 00000000..fb678a00 --- /dev/null +++ b/src/destructure.jl @@ -0,0 +1,108 @@ + +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo +const NoT = NoTangent() + +""" + destructure([T], model) -> vector, reconstructor + +Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model +to a `Vector{T}`, and returns also a function which reverses this transformation. +Differentiable. +""" +function destructure(::Type{T}, x) where T + flat, off = alpha!(x, T[]) + len = length(flat) + # flat, newflat -> beta(x, off, newflat; len) + flat, Restucture(x, off, len) +end + +struct Restucture{T,S} + model::T + offsets::S + length::Int +end +(re::Restucture)(flat) = beta(re.model, re.offsets, flat; len = re.length) +Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") + +# This flattens a model, and returns a web of offsets for later use: +function alpha!(x, flat::AbstractVector) + isempty(flat) || error("this won't work") + isnumeric(x) && return append!(flat, x), 0 # trivial case + off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y + append!(flat, y) + length(flat) - length(y) + end + flat, off +end + +function ChainRulesCore.rrule(::typeof(alpha!), x, flat) + flat′, off = alpha!(x, flat) + len = length(flat′) + alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT) + (flat′, off), alpha_back +end + +# This reconstructs either a model like x, or a gradient for it: +function beta(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...) + len == length(flat) || error("wrong length") + fmap(x, off; exclude = isnumeric, walk, kw...) 
do y, o + _getat(y, o, flat) + end +end + +_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) +_getat(y::AbstractArray, o::Int, flat::AbstractVector) = + ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes + +function _trainable_biwalk(f, x, aux) + ch, re = functor(typeof(x), x) + au, _ = functor(typeof(x), aux) + trainmap(f, ch, _trainable(x), au) |> re +end + +function trainmap(f, ch, tr, aux) + map(ch, tr, aux) do c, t, a + isnothing(t) ? c : f(t, a) + end +end + +function _Tangent_biwalk(f, x, aux) # use with prune = true + ch, re = functor(typeof(x), x) + au, _ = functor(typeof(x), aux) + y = trainmap(f, ch, _trainable(x), au) + y isa Tuple{} && return NoT + Tangent{typeof(x), typeof(y)}(y) +end +# _Tangent_biwalk(f, x::Tuple{}, aux) = NoT + +function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len) + dflat = map!(zero, similar(flat, float(eltype(flat))), flat) + beta_back(dx) = (NoT, NoT, NoT, gamma!(x, dx, off, dflat)) + beta(x, off, flat; len), beta_back +end + +# This is the gradient of model reconstruction, accumulating duplicates: +function gamma!(x, dx, off, flat::AbstractVector) + x′, _ = functor(typeof(x), x) + dx′, _ = functor(typeof(x), dx) + off′, _ = functor(typeof(x), off) + foreach((xᵢ, dxᵢ, oᵢ) -> gamma!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) + flat +end +function gamma!(x, dx, off::Integer, flat::AbstractVector) + @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes, hence no fmap. + flat +end +gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing +gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity + +# Least importantly, this infers the eltype if one is not given: +destructure(x) = destructure(omega(x), x) +function omega(x) + T = Bool + fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y + T = promote_type(T, eltype(y)) + end + T +end +ChainRulesCore.@non_differentiable omega(::Any) diff --git a/src/interface.jl b/src/interface.jl index 80f87dcc..4864ae16 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -70,7 +70,7 @@ trainable(x) = functor(x)[1] _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) -_trainable(ch::Tuple, tr::Tuple) = tr +_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" map(c -> c in tr ? 
c : nothing, ch) diff --git a/test/destructure.jl b/test/destructure.jl new file mode 100644 index 00000000..dd5103c6 --- /dev/null +++ b/test/destructure.jl @@ -0,0 +1,84 @@ + +m1 = collect(1:3.0) +m2 = (collect(1:3.0), collect(4:6.0)) +m3 = (x = m1, y = sin, z = collect(4:6.0)) +m4 = (x = m1, y = m1, z = collect(4:6.0)) +m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) +m6 = (a = m1, b = [4.0 + im], c = m1) +m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) + +@testset "flatten & restore" begin + @test destructure(Int, m1)[1] isa Vector{Int} + @test destructure(m1)[1] isa Vector{Float64} + + @test destructure(m1)[1] == 1:3 + @test destructure(m2)[1] == 1:6 + @test destructure(m3)[1] == 1:6 + @test destructure(m4)[1] == 1:6 + @test destructure(m5)[1] == vcat(1:6, 4:6) + @test destructure(m6)[1] == vcat(1:3, 4 + im) + + @test destructure(m1)[2](7:9) == [7,8,9] + @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9]) + @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9]) + m4′ = destructure(m4)[2](4:9) + @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9]) + @test m4′.x === m4′.y + m5′ = destructure(m5)[2](reverse(1:9)) + @test m5′.a[1].x === m5′.b[1] + @test m5′.b[2] === false + m6′ = destructure(m6)[2]((4:7) .+ (1:4) .* im) + @test m6′.a isa Vector{Float64} + @test m6′.a == 4:6 + @test m6′.a === m6′.c + @test m6′.b == [7 + 4im] + + @test destructure(m7)[1] == 1:3 + m7′ = destructure(m7)[2]([10,20,30]) + @test m7′.a == (sin, [10,20,30]) + @test m7′.b == (cos, [4,5,6]) + @test m7′.c == (tan, [7,8,9]) + + @test_throws Exception destructure(m7)[2]([10,20]) + @test_throws Exception destructure(m7)[2]([10,20,30,40]) +end + +@testset "gradient of flatten" begin + @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] + @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) + @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) + @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) + + g5 = gradient(m -> destructure(m)[1][3], m5)[1] + @test g5.a[1].x == [0,0,1] + @test g5.a[2] === nothing + + g6 = gradient(m -> imag(destructure(m)[1][4]), m6)[1] + @test g6.a == [0,0,0] + @test g6.a isa Vector{Float64} + @test g6.b == [0+im] +end + +@testset "gradient of rebuild" begin + re1 = destructure(m1)[2] + @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] + re2 = destructure(m2)[2] + @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] + re3 = destructure(m3)[2] + @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] + @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + + re4 = destructure(m4)[2] + @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] + @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] + @test gradient(rand(6)) do x + m = re4(x) + m.x[1] + 2*m.y[2] + 3*m.z[3] + end[1] == [1,2,0, 0,0,3] + + re7 = destructure(m7)[2] + @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] + @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] + @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] +end diff --git a/test/runtests.jl b/test/runtests.jl index 825d977e..55068739 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -164,8 +164,11 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test_throws ArgumentError Optimisers.setup(ADAMW(), m2) end - @info "finished feature testing" end + @testset 
verbose=true "Optimisation Rules" begin + include("destructure.jl") + end + @info "finished feature testing" @testset verbose=true "Optimisation Rules" begin include("rules.jl") end From 5a18607221ba032f8500d74b6955e3e30827f72e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 6 Feb 2022 01:19:33 -0500 Subject: [PATCH 02/13] add a test --- test/destructure.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/destructure.jl b/test/destructure.jl index dd5103c6..c8a3c1f2 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -82,3 +82,33 @@ end @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] end + +@testset "Flux issue 1826" begin + v, re = destructure((x=[1,2.0], y=[3,4,5.0])) + @test gradient(zero(v)) do w + m = re(w) + 5 * sum(m.x) + 7 * sum(m[2]) # uses both x and y + end == ([5.0, 5.0, 7.0, 7.0, 7.0],) + # This, using only x, was broken on Flux: + @test gradient(w -> sum(re(w).x), zero(v)) == ([1.0, 1.0, 0.0, 0.0, 0.0],) + + sh = [7,7.0]; + v, re = destructure((x=sh, y=[3.0,4.0], z=sh)) # shared array in the model + @test v == [7, 7, 3, 4] + @test re([1,10,100,1000]) == (x = [1, 10], y = [100, 1000], z = [1, 10]) + + @test gradient(zero(v)) do w + m = re(w) + 3 * sum(m.x) + 13 * sum(m.z) # no dependence on y, but two distinct gradient arrays + end == ([16, 16, 0, 0],) # Flux gave ([3.0, 3.0, 13.0, 13.0],) + + @test gradient(zero(v)) do w + m = re(w) + 4(sum(m.x) + sum(m.z)) # now two gradients are ===, so it eliminates one + end == ([8,8,0,0],) + + @test gradient(zero(v)) do w + m = re(w) + 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one + end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) +end From b70875fffc638bb339d3c9386ab900a8035c40e9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 6 Feb 2022 01:19:43 -0500 Subject: [PATCH 03/13] tidy --- src/destructure.jl | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index fb678a00..ffff14b5 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -8,11 +8,18 @@ const NoT = NoTangent() Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model to a `Vector{T}`, and returns also a function which reverses this transformation. Differentiable. 
+ +# Example +```jldoctest +julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0]))) +([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3)) + +julia> re([10,20,30]) +(x = [10.0, 20.0], y = (sin, [30.0])) +``` """ function destructure(::Type{T}, x) where T - flat, off = alpha!(x, T[]) - len = length(flat) - # flat, newflat -> beta(x, off, newflat; len) + flat, off, len = alpha!(x, T[]) flat, Restucture(x, off, len) end @@ -32,14 +39,13 @@ function alpha!(x, flat::AbstractVector) append!(flat, y) length(flat) - length(y) end - flat, off + flat, off, length(flat) end function ChainRulesCore.rrule(::typeof(alpha!), x, flat) - flat′, off = alpha!(x, flat) - len = length(flat′) + flat′, off, len = alpha!(x, flat) alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT) - (flat′, off), alpha_back + (flat, off, len), alpha_back end # This reconstructs either a model like x, or a gradient for it: @@ -73,7 +79,6 @@ function _Tangent_biwalk(f, x, aux) # use with prune = true y isa Tuple{} && return NoT Tangent{typeof(x), typeof(y)}(y) end -# _Tangent_biwalk(f, x::Tuple{}, aux) = NoT function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len) dflat = map!(zero, similar(flat, float(eltype(flat))), flat) From e325f666678d868fa6cc09a3c52ebc1b35f21144 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 6 Feb 2022 22:32:36 -0500 Subject: [PATCH 04/13] replace append! with reduce(vcat, ...) --- src/destructure.jl | 40 ++++++++++++++++------------------------ test/destructure.jl | 2 -- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index ffff14b5..3be43f6c 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -3,10 +3,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo const NoT = NoTangent() """ - destructure([T], model) -> vector, reconstructor + destructure(model) -> vector, reconstructor Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model -to a `Vector{T}`, and returns also a function which reverses this transformation. +to a vector, and returns also a function which reverses this transformation. Differentiable. 
# Example @@ -18,8 +18,8 @@ julia> re([10,20,30]) (x = [10.0, 20.0], y = (sin, [30.0])) ``` """ -function destructure(::Type{T}, x) where T - flat, off, len = alpha!(x, T[]) +function destructure(x) + flat, off, len = alpha(x) flat, Restucture(x, off, len) end @@ -32,19 +32,22 @@ end Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") # This flattens a model, and returns a web of offsets for later use: -function alpha!(x, flat::AbstractVector) - isempty(flat) || error("this won't work") - isnumeric(x) && return append!(flat, x), 0 # trivial case +function alpha(x) + isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case + arrays = AbstractVector[] + len = Ref(0) off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y - append!(flat, y) - length(flat) - length(y) + push!(arrays, vec(y)) + o = len[] + len[] = o + length(y) + o end - flat, off, length(flat) + reduce(vcat, arrays), off, len[] end -function ChainRulesCore.rrule(::typeof(alpha!), x, flat) - flat′, off, len = alpha!(x, flat) - alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT) +function ChainRulesCore.rrule(::typeof(alpha), x) + flat, off, len = alpha(x) + alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len)) (flat, off, len), alpha_back end @@ -100,14 +103,3 @@ function gamma!(x, dx, off::Integer, flat::AbstractVector) end gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity - -# Least importantly, this infers the eltype if one is not given: -destructure(x) = destructure(omega(x), x) -function omega(x) - T = Bool - fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y - T = promote_type(T, eltype(y)) - end - T -end -ChainRulesCore.@non_differentiable omega(::Any) diff --git a/test/destructure.jl b/test/destructure.jl index c8a3c1f2..ad685c7b 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -8,9 +8,7 @@ m6 = (a = m1, b = [4.0 + im], c = m1) m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) @testset "flatten & restore" begin - @test destructure(Int, m1)[1] isa Vector{Int} @test destructure(m1)[1] isa Vector{Float64} - @test destructure(m1)[1] == 1:3 @test destructure(m2)[1] == 1:6 @test destructure(m3)[1] == 1:6 From c686fc5e386523b0d45730740f39c7fc5f57617a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 6 Feb 2022 22:53:14 -0500 Subject: [PATCH 05/13] testset names --- test/destructure.jl | 2 +- test/rules.jl | 8 ++++---- test/runtests.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/destructure.jl b/test/destructure.jl index ad685c7b..54b368e7 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -7,7 +7,7 @@ m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) m6 = (a = m1, b = [4.0 + im], c = m1) m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) -@testset "flatten & restore" begin +@testset "flatten & rebuild" begin @test destructure(m1)[1] isa Vector{Float64} @test destructure(m1)[1] == 1:3 @test destructure(m2)[1] == 1:6 diff --git a/test/rules.jl b/test/rules.jl index c8697683..ffb4ca65 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -44,7 +44,7 @@ end end end -@testset verbose=true "simple sum" begin +@testset "simple sum" begin empty!(LOG) @testset "$(name(o))" for o 
in RULES m = shuffle!(reshape(1:64, 8, 8) .+ 0.0) @@ -79,7 +79,7 @@ end end end -@testset verbose=true "StaticArrays" begin +@testset "StaticArrays" begin empty!(LOG) @testset "$(name(o))" for o in RULES W1 = @SMatrix randn(10, 10) @@ -157,7 +157,7 @@ end end end -@testset verbose=true "mutation check" begin +@testset "mutation check" begin # If @lazy captures a matrix which is later mutated, the results won't agree here: @testset "$(name(o))" for o in RULES model = Float64.(rand(Int8, 8)) @@ -174,7 +174,7 @@ end end end -@testset "with complex numebers: Flux#1776" begin +@testset "with complex numbers: Flux#1776" begin empty!(LOG) @testset "$(name(opt))" for opt in [ # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too: diff --git a/test/runtests.jl b/test/runtests.jl index 55068739..a60dacd2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -165,7 +165,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) end end - @testset verbose=true "Optimisation Rules" begin + @testset verbose=true "Destructure" begin include("destructure.jl") end @info "finished feature testing" From 520efbe9cc9177b9465930682d1dcd7911cf3ebc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Feb 2022 22:14:43 -0500 Subject: [PATCH 06/13] rename everything --- Project.toml | 2 +- docs/src/api.md | 1 + src/Optimisers.jl | 3 +- src/destructure.jl | 73 ++++++++++++++++++++++++++++++---------------- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index d91c01d2..66c062ed 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Functors = "0.2.7" +Functors = "0.2.8" julia = "1.6" [extras] diff --git a/docs/src/api.md b/docs/src/api.md index e43df995..edd8be32 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -46,6 +46,7 @@ Such restrictions are also obeyed by this function for flattening a model: ```@docs Optimisers.destructure +Optimisers.Restructure ``` ## Rule Definition diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 3ef067f1..417b90d4 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -4,8 +4,9 @@ using Functors: functor, fmap, isleaf using LinearAlgebra include("interface.jl") + include("destructure.jl") -export destructure +export destructure, total, total2 include("rules.jl") export Descent, ADAM, Momentum, Nesterov, RMSProp, diff --git a/src/destructure.jl b/src/destructure.jl index 3be43f6c..7d2cfa6e 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -19,20 +19,42 @@ julia> re([10,20,30]) ``` """ function destructure(x) - flat, off, len = alpha(x) - flat, Restucture(x, off, len) + flat, off, len = _flatten(x) + flat, Restructure(x, off, len) end -struct Restucture{T,S} +""" + Restructure(Model, ..., length) + +This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with +new parameters from vector `p`. If the model is callable, then `re(x, p)` . 
+ +# Example +```julia +julia> using Flux, Optimisers + +julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid)) +([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6)) + +julia> m = re(-4:1) +Dense(2, 2, σ) # 6 parameters + +julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1) +true +``` +""" +struct Restructure{T,S} model::T offsets::S length::Int end -(re::Restucture)(flat) = beta(re.model, re.offsets, flat; len = re.length) -Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") +(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat; len = re.length) +(re::Restructure)(x, flat::AbstractVector) = re(flat)(x) +Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") +Base.length(re::Restructure) = re.length # This flattens a model, and returns a web of offsets for later use: -function alpha(x) +function _flatten(x) isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case arrays = AbstractVector[] len = Ref(0) @@ -45,14 +67,14 @@ function alpha(x) reduce(vcat, arrays), off, len[] end -function ChainRulesCore.rrule(::typeof(alpha), x) - flat, off, len = alpha(x) - alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len)) - (flat, off, len), alpha_back +function ChainRulesCore.rrule(::typeof(_flatten), x) + flat, off, len = _flatten(x) + _flatten_back((dflat, _)) = (NoT, _rebuild(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len)) + (flat, off, len), _flatten_back end # This reconstructs either a model like x, or a gradient for it: -function beta(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...) +function _rebuild(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...) len == length(flat) || error("wrong length") fmap(x, off; exclude = isnumeric, walk, kw...) do y, o _getat(y, o, flat) @@ -66,40 +88,41 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) = function _trainable_biwalk(f, x, aux) ch, re = functor(typeof(x), x) au, _ = functor(typeof(x), aux) - trainmap(f, ch, _trainable(x), au) |> re + _trainmap(f, ch, _trainable(x), au) |> re end -function trainmap(f, ch, tr, aux) - map(ch, tr, aux) do c, t, a +function _trainmap(f, ch, tr, aux) + map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)?? isnothing(t) ? 
c : f(t, a) end end -function _Tangent_biwalk(f, x, aux) # use with prune = true +function _Tangent_biwalk(f, x, aux) # use with prune = NoT ch, re = functor(typeof(x), x) au, _ = functor(typeof(x), aux) - y = trainmap(f, ch, _trainable(x), au) + y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT Tangent{typeof(x), typeof(y)}(y) end -function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len) +function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len) dflat = map!(zero, similar(flat, float(eltype(flat))), flat) - beta_back(dx) = (NoT, NoT, NoT, gamma!(x, dx, off, dflat)) - beta(x, off, flat; len), beta_back + _rebuild_back(dx) = (NoT, NoT, NoT, _accumulate!(x, dx, off, dflat)) + _rebuild(x, off, flat; len), _rebuild_back end # This is the gradient of model reconstruction, accumulating duplicates: -function gamma!(x, dx, off, flat::AbstractVector) +function _accumulate!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), dx) off′, _ = functor(typeof(x), off) - foreach((xᵢ, dxᵢ, oᵢ) -> gamma!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) + foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat end -function gamma!(x, dx, off::Integer, flat::AbstractVector) - @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes, hence no fmap. +function _accumulate!(x, dx, off::Integer, flat::AbstractVector) + @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes flat end -gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing -gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity +_accumulate!(x, dx::Zero, off, flat::AbstractVector) = nothing +_accumulate!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity + From af14f84d69dd2b88f0786029dd6da7a9eff7fe48 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 9 Feb 2022 19:02:21 -0500 Subject: [PATCH 07/13] tweak --- src/destructure.jl | 24 ++++++++++++------------ test/runtests.jl | 1 - 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 7d2cfa6e..fb900e82 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -1,5 +1,5 @@ -using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk const NoT = NoTangent() """ @@ -11,11 +11,11 @@ Differentiable. # Example ```jldoctest -julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0]))) -([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3)) +julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im]))) +(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3)) -julia> re([10,20,30]) -(x = [10.0, 20.0], y = (sin, [30.0])) +julia> re([3, 5-im, 7+11im]) +(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im])) ``` """ function destructure(x) @@ -27,7 +27,7 @@ end Restructure(Model, ..., length) This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with -new parameters from vector `p`. If the model is callable, then `re(x, p)` . +new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`. 
# Example ```julia @@ -107,22 +107,22 @@ end function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len) dflat = map!(zero, similar(flat, float(eltype(flat))), flat) - _rebuild_back(dx) = (NoT, NoT, NoT, _accumulate!(x, dx, off, dflat)) + _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat)) _rebuild(x, off, flat; len), _rebuild_back end # This is the gradient of model reconstruction, accumulating duplicates: -function _accumulate!(x, dx, off, flat::AbstractVector) +function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), dx) off′, _ = functor(typeof(x), off) - foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) + foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat end -function _accumulate!(x, dx, off::Integer, flat::AbstractVector) +function _grad!(x, dx, off::Integer, flat::AbstractVector) @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes flat end -_accumulate!(x, dx::Zero, off, flat::AbstractVector) = nothing -_accumulate!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity +_grad!(x, dx::Zero, off, flat::AbstractVector) = nothing +_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity diff --git a/test/runtests.jl b/test/runtests.jl index a60dacd2..ef216458 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -168,7 +168,6 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @testset verbose=true "Destructure" begin include("destructure.jl") end - @info "finished feature testing" @testset verbose=true "Optimisation Rules" begin include("rules.jl") end From 6f3eefab3012fc5f6e3c9f2c9a8c7cec3103d0e2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 9 Feb 2022 21:36:20 -0500 Subject: [PATCH 08/13] two broken tests --- test/destructure.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/destructure.jl b/test/destructure.jl index 54b368e7..f53f9e28 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -56,6 +56,11 @@ end @test g6.a == [0,0,0] @test g6.a isa Vector{Float64} @test g6.b == [0+im] + + # Second derivative -- no method matching rrule(::typeof(Optimisers._rebuild), ...? 
+ @test_broken gradient([1,2,3]) do v + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6]))[1][1]) + end[1] ≈ [8,16,24] end @testset "gradient of rebuild" begin @@ -79,6 +84,11 @@ end @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + + # Second derivative -- error from _tryaxes(x::Tangent) in Zygote's map rule + @test_broken gradient(collect(1:6)) do y + sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) + end[1] ≈ [8,16,24,0,0,0] end @testset "Flux issue 1826" begin From 17b57f04546101eb44046f8ab05ba2fd90c4ca01 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Feb 2022 18:15:04 -0500 Subject: [PATCH 09/13] make len positional, fix a bug --- src/destructure.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index fb900e82..71b25c65 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -48,7 +48,7 @@ struct Restructure{T,S} offsets::S length::Int end -(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat; len = re.length) +(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length) (re::Restructure)(x, flat::AbstractVector) = re(flat)(x) Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") Base.length(re::Restructure) = re.length @@ -69,13 +69,13 @@ end function ChainRulesCore.rrule(::typeof(_flatten), x) flat, off, len = _flatten(x) - _flatten_back((dflat, _)) = (NoT, _rebuild(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len)) + _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT)) (flat, off, len), _flatten_back end # This reconstructs either a model like x, or a gradient for it: -function _rebuild(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...) - len == length(flat) || error("wrong length") +function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...) + len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))")) fmap(x, off; exclude = isnumeric, walk, kw...) do y, o _getat(y, o, flat) end @@ -105,12 +105,14 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT Tangent{typeof(x), typeof(y)}(y) end -function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len) - dflat = map!(zero, similar(flat, float(eltype(flat))), flat) - _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat)) - _rebuild(x, off, flat; len), _rebuild_back +function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...) + _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT) + _rebuild(x, off, flat, len; kw...), _rebuild_back end +_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad! 
+ChainRulesCore.@non_differentiable _zero(x) + # This is the gradient of model reconstruction, accumulating duplicates: function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) From 337f365ac54c9dce73daf7baca47ceb8c76e2bee Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Feb 2022 18:16:27 -0500 Subject: [PATCH 10/13] second derivatives --- src/destructure.jl | 9 +++++++-- test/destructure.jl | 29 +++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 71b25c65..8e6b75dd 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -125,6 +125,11 @@ function _grad!(x, dx, off::Integer, flat::AbstractVector) @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes flat end -_grad!(x, dx::Zero, off, flat::AbstractVector) = nothing -_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity +_grad!(x, dx::Zero, off, flat::AbstractVector) = dx +_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity +function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) + println("grad! fwd ", length(flat)) + _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT) + _grad!(x, dx, off, flat), _grad_back +end diff --git a/test/destructure.jl b/test/destructure.jl index f53f9e28..55ab37df 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -57,10 +57,15 @@ end @test g6.a isa Vector{Float64} @test g6.b == [0+im] - # Second derivative -- no method matching rrule(::typeof(Optimisers._rebuild), ...? - @test_broken gradient([1,2,3]) do v - sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6]))[1][1]) - end[1] ≈ [8,16,24] + @testset "second derivative" begin + @test_broken gradient([1,2,3.0]) do v + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) + end[1] ≈ [8,16,24] + + @test_skip gradient([1,2,3.0]) do v + sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1]) + end + end end @testset "gradient of rebuild" begin @@ -85,10 +90,18 @@ end @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] - # Second derivative -- error from _tryaxes(x::Tangent) in Zygote's map rule - @test_broken gradient(collect(1:6)) do y - sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) - end[1] ≈ [8,16,24,0,0,0] + @testset "second derivative" begin + # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} + @test_broken gradient(collect(1:6.0)) do y + sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) + end[1] ≈ [8,16,24,0,0,0] + # This fixes it! 
+ # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) + @test_skip gradient(collect(1:6.0)) do y + sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) + end[1] + # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) + end end @testset "Flux issue 1826" begin From 756b4502c2b5f1909dee21ea5287a1b3bfb2a4d2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Feb 2022 21:59:27 -0500 Subject: [PATCH 11/13] arrays of arrays --- src/destructure.jl | 7 ++++++- src/interface.jl | 1 + test/destructure.jl | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/destructure.jl b/src/destructure.jl index 8e6b75dd..75e1876a 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -102,7 +102,12 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT au, _ = functor(typeof(x), aux) y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT - Tangent{typeof(x), typeof(y)}(y) + p = ProjectTo(x) + if p isa ProjectTo # e.g. Array, NamedTuple + p(y) + else # p === identity for unknown structs + Tangent{typeof(x), typeof(y)}(y) + end end function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...) diff --git a/src/interface.jl b/src/interface.jl index 4864ae16..1116b90a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -71,6 +71,7 @@ trainable(x) = functor(x)[1] _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr +_trainable(ch::AbstractArray, tr::AbstractArray) = tr function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" map(c -> c in tr ? 
c : nothing, ch) diff --git a/test/destructure.jl b/test/destructure.jl index 55ab37df..6de5a6af 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -6,6 +6,7 @@ m4 = (x = m1, y = m1, z = collect(4:6.0)) m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) m6 = (a = m1, b = [4.0 + im], c = m1) m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) +m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] @testset "flatten & rebuild" begin @test destructure(m1)[1] isa Vector{Float64} @@ -31,12 +32,20 @@ m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0 @test m6′.a === m6′.c @test m6′.b == [7 + 4im] + # struct, trainable @test destructure(m7)[1] == 1:3 m7′ = destructure(m7)[2]([10,20,30]) @test m7′.a == (sin, [10,20,30]) @test m7′.b == (cos, [4,5,6]) @test m7′.c == (tan, [7,8,9]) + @test destructure(m8)[1] == 1:5 + m8′ = destructure(m8)[2](1:5) + @test m8′[1].x === m8′[1].y + @test m8′[2].b.y === false + @test m8′[3][1] == [5.0] + + # errors @test_throws Exception destructure(m7)[2]([10,20]) @test_throws Exception destructure(m7)[2]([10,20,30,40]) end @@ -57,6 +66,11 @@ end @test g6.a isa Vector{Float64} @test g6.b == [0+im] + g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] + @test g8[1].x == [2,4,6] + @test g8[2].b.x == [8] + @test g8[3] == [[10.0]] + @testset "second derivative" begin @test_broken gradient([1,2,3.0]) do v sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) @@ -90,6 +104,10 @@ end @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + v8, re8 = destructure(m8) + @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + @testset "second derivative" begin # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} @test_broken gradient(collect(1:6.0)) do y From d95a1472616759cff6c19c6143c3a594383dfb97 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Feb 2022 10:11:17 -0500 Subject: [PATCH 12/13] more... 
the dimensionmismatch bug is not here --- src/destructure.jl | 17 +++++++++++------ test/destructure.jl | 31 ++++++++++++++++++++++--------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 75e1876a..92b4be68 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -2,6 +2,9 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk const NoT = NoTangent() +base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure)) +base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version + """ destructure(model) -> vector, reconstructor @@ -55,11 +58,11 @@ Base.length(re::Restructure) = re.length # This flattens a model, and returns a web of offsets for later use: function _flatten(x) - isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case + isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case arrays = AbstractVector[] len = Ref(0) off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y - push!(arrays, vec(y)) + push!(arrays, _vec(y)) o = len[] len[] = o + length(y) o @@ -67,9 +70,12 @@ function _flatten(x) reduce(vcat, arrays), off, len[] end +_vec(x::Number) = LinRange(x,x,1) +_vec(x::AbstractArray) = vec(x) + function ChainRulesCore.rrule(::typeof(_flatten), x) flat, off, len = _flatten(x) - _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT)) + _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT)) (flat, off, len), _flatten_back end @@ -92,7 +98,7 @@ function _trainable_biwalk(f, x, aux) end function _trainmap(f, ch, tr, aux) - map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)?? + map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c) isnothing(t) ? c : f(t, a) end end @@ -121,7 +127,7 @@ ChainRulesCore.@non_differentiable _zero(x) # This is the gradient of model reconstruction, accumulating duplicates: function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) - dx′, _ = functor(typeof(x), dx) + dx′, _ = functor(typeof(x), base(dx)) off′, _ = functor(typeof(x), off) foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat @@ -134,7 +140,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) - println("grad! 
fwd ", length(flat)) _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT) _grad!(x, dx, off, flat), _grad_back end diff --git a/test/destructure.jl b/test/destructure.jl index 6de5a6af..40c4360c 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -2,7 +2,7 @@ m1 = collect(1:3.0) m2 = (collect(1:3.0), collect(4:6.0)) m3 = (x = m1, y = sin, z = collect(4:6.0)) -m4 = (x = m1, y = m1, z = collect(4:6.0)) +m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) m6 = (a = m1, b = [4.0 + im], c = m1) m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) @@ -72,13 +72,24 @@ end @test g8[3] == [[10.0]] @testset "second derivative" begin - @test_broken gradient([1,2,3.0]) do v + @test gradient([1,2,3.0]) do v sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) end[1] ≈ [8,16,24] + # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: + # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... + # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing + # With Zygote, instead: + # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) + + @test gradient([1,2,3.0]) do v + sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1]) + end[1] == [378, 378, 378] - @test_skip gradient([1,2,3.0]) do v - sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1]) - end + @test_broken gradient([1,2,3.0]) do v + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) + end[1] ≈ [8,16,24] + # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) + # Diffractor error in perform_optic_transform end end @@ -109,15 +120,17 @@ end @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] @testset "second derivative" begin - # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} @test_broken gradient(collect(1:6.0)) do y sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) end[1] ≈ [8,16,24,0,0,0] - # This fixes it! + # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. 
Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} + # with Zygote, which can be fixed by: # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) - @test_skip gradient(collect(1:6.0)) do y + + @test_broken gradient(collect(1:6.0)) do y sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) - end[1] + end[1] ≈ [0,0,0,32,40,48] + # Not fixed by this: # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) end end From 65e136e48e57b2676dd1c85ed2b1c8302aaadb34 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Feb 2022 18:53:45 -0500 Subject: [PATCH 13/13] warnings --- src/destructure.jl | 13 ++++++++++--- src/interface.jl | 2 +- test/runtests.jl | 1 + 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 92b4be68..3ace52ec 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -2,9 +2,6 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk const NoT = NoTangent() -base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure)) -base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version - """ destructure(model) -> vector, reconstructor @@ -75,6 +72,7 @@ _vec(x::AbstractArray) = vec(x) function ChainRulesCore.rrule(::typeof(_flatten), x) flat, off, len = _flatten(x) + _maybewarn() _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT)) (flat, off, len), _flatten_back end @@ -139,7 +137,16 @@ end _grad!(x, dx::Zero, off, flat::AbstractVector) = dx _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity +# These are only needed for 2nd derivatives: function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) + @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3 _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT) _grad!(x, dx, off, flat), _grad_back end +base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure)) +base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version +_maybewarn() = nothing +function ChainRulesCore.rrule(::typeof(_maybewarn)) + @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3 + nothing, _ -> (NoT,) +end diff --git a/src/interface.jl b/src/interface.jl index 1116b90a..235c2e94 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -73,7 +73,7 @@ _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr _trainable(ch::AbstractArray, tr::AbstractArray) = tr function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple - @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" + @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3 map(c -> c in tr ? 
c : nothing, ch) end diff --git a/test/runtests.jl b/test/runtests.jl index ef216458..d47bce08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) end @testset "trainable subset" begin + @info "ignore these warnings about trainable, testing the old path" # Foo has an old-style tuple trainable, both elements mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5)) sf = Optimisers.setup(Descent(0.1), mf)
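
Taken together, these patches make `destructure(model)` return a flat parameter vector plus a `Restructure` closure whose forward and reverse passes both respect `trainable` restrictions and tied arrays. A minimal usage sketch, mirroring values from `test/destructure.jl` above (assuming Optimisers.jl with these patches applied, and Zygote supplying `gradient`):

```julia
# Illustrative sketch only; the expected values are copied from test/destructure.jl above.
using Optimisers, Zygote

m1 = collect(1:3.0)
m4 = (x = m1, y = m1, z = collect(4:6.0))    # x and y are the same array (tied parameters)

v, re = destructure(m4)
v == [1, 2, 3, 4, 5, 6]                      # the tied array is flattened only once

m4′ = re(4:9)                                # rebuild a model of the same shape from new parameters
m4′ == (x = [4, 5, 6], y = [4, 5, 6], z = [7, 8, 9])
m4′.x === m4′.y                              # the tie is preserved on rebuild

# Gradient of flattening: the second occurrence of a tied array is pruned to `nothing`:
gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0, 1, 0], y = nothing, z = [0, 0, 0])

# Gradient of rebuilding accumulates duplicates (cf. the "Flux issue 1826" testset):
sh = [7.0, 7.0]
v2, re2 = destructure((x = sh, y = [3.0, 4.0], z = sh))
gradient(zero(v2)) do w
  m = re2(w)
  3 * sum(m.x) + 13 * sum(m.z)               # two distinct gradient arrays land in the shared slots
end == ([16, 16, 0, 0],)
```

Storing each tied array once keeps the flat vector minimal, while the `_grad!` step ("accumulating duplicates") is what sums the separate `x` and `z` contributions into the shared slots in the last call.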