From 4dc70b5232f5f4d5c662e14a02b0681c76e62724 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 12 Oct 2021 00:01:43 +0200 Subject: [PATCH 1/5] destructure returns only trainable params --- src/Flux.jl | 1 + src/functor.jl | 119 +++++++++++++++++++++++++++++++- src/layers/basic.jl | 2 +- src/layers/show.jl | 4 +- src/utils.jl | 53 -------------- test/functor.jl | 165 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++ test/utils.jl | 17 ----- 8 files changed, 291 insertions(+), 74 deletions(-) create mode 100644 test/functor.jl diff --git a/src/Flux.jl b/src/Flux.jl index 80d999de38..1519bfa689 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -8,6 +8,7 @@ using Zygote, MacroTools, Juno, Reexport using MacroTools: @forward @reexport using NNlib using Zygote: Params, @adjoint, gradient, pullback, @nograd +using Functors: Functors, @functor, functor, fmap export gradient export Chain, Dense, Maxout, SkipConnection, Parallel, flatten, diff --git a/src/functor.jl b/src/functor.jl index bef3559c2f..59144c4208 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,7 +1,6 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky using Zygote: IdSet -import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray trainable(m) = functor(m)[1] @@ -38,6 +37,124 @@ Possible values include: """ trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode) + +# Flattening models to weight vectors, and back + +function _restructure(m, xs) + i = 0 + filter = (x, c) -> any(y -> c === y, trainable(x)) + walk = filtered_walk(filter) + m̄ = fmap(m; walk) do x + x isa AbstractArray{<:Number} || return x + x = reshape(xs[i .+ (1:length(x))], size(x)) + i += length(x) + return x + end + length(xs) == i || @warn "Expected $(i) params, got $(length(xs))" + return m̄ +end + +@adjoint function _restructure(m, xs) + m̄, numel = _restructure(m, xs), length(xs) + function _restructure_pullback(dm) + xs′ = destructure(dm)[1] + numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))" + return (nothing, xs′) + end + return m̄, _restructure_pullback +end + +""" + destructure(m) +Flatten a model's parameters into a single weight vector. + julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) + Chain(Dense(10, 5, σ), Dense(5, 2), softmax) + julia> θ, re = destructure(m); + julia> θ + 67-element Vector{Float32}: + -0.1407104 + ... +The second return value `re` allows you to reconstruct the original network after making +modifications to the weight vector (for example, with a hypernetwork). + julia> re(θ .* 2) + Chain(Dense(10, 5, σ), Dense(5, 2), softmax) +""" +function destructure(m) + xs = Zygote.Buffer([]) + collect_params!(xs, m) + return vcat(vec.(copy(xs))...), p -> _restructure(m, p) +end + +function collect_params!(xs, m) + filter = (x, c) -> any(y -> c === y, trainable(x)) + walk = filtered_walk(filter) + fmap(m; walk) do x + x isa AbstractArray{<:Number} && push!(xs, x) + return x + end +end + +function filtered_walk(filter) + seen = IdSet() + + function walk(f, x) + x in seen && return x + push!(seen, x) + + children, reconstruct = functor(x) + mappedchildren = map(children) do c + filter(x, c) ? f(c) : c + end + reconstruct(mappedchildren) + end + + return walk +end + + +""" + params(m...) + +Collect trainable parameters (a.k.a. numerical arrays) +from the input model(s) `m` into a [`Zygote.Params`](@ref) object. 
+ +Only the parameters that can be reached by recursion +on the [`trainable`](@ref) children of +the tree with root `m` are collected. + +# Usage + +```julia-repl +julia> m = Dense(ones(2, 3), zeros(2)) +Dense(3, 2) # 8 parameters + +julia> ps = Flux.params(m) +Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]]) + +julia> x = ones(3) +3-element Vector{Float64}: + 1.0 + 1.0 + 1.0 + +julia> gs = gradient(() -> sum(2 .* m(x)), ps) +Grads(...) + +julia> gs[m.weight] +2×3 Matrix{Float64}: + 2.0 2.0 2.0 + 2.0 2.0 2.0 +``` +""" +function params end + +## TODO This causes some test regressions. Why? +# function params(m...) +# ps = Params() +# collect_params!(ps, m) +# return ps +# end + params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e40457ef53..7b7c2285fe 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -41,7 +41,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys -functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...) +Functors.functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...) applychain(::Tuple{}, x) = x applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) diff --git a/src/layers/show.jl b/src/layers/show.jl index 791d2511ca..121fe2a715 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -43,7 +43,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end -_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: +_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: _show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell _show_leaflike(::Diagonal) = true # appears inside LayerNorm @@ -97,7 +97,7 @@ function _big_finale(io::IO, m) end _childarray_sum(f, x::AbstractArray) = f(x) -_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) +_childarray_sum(f, x) = Functors.isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) # utility functions diff --git a/src/utils.jl b/src/utils.jl index c1888829d4..dc222b2644 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -629,59 +629,6 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs)) [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n] end -# Flattening models to weight vectors, and back - -function _restructure(m, xs) - i = 0 - m̄ = fmap(m) do x - x isa AbstractArray || return x - x = reshape(xs[i.+(1:length(x))], size(x)) - i += length(x) - return x - end - length(xs) == i || @warn "Expected $(i) params, got $(length(xs))" - return m̄ -end - -@adjoint function _restructure(m, xs) - m̄, numel = _restructure(m, xs), length(xs) - function _restructure_pullback(dm) - xs′ = destructure(dm)[1] - numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))" - return (nothing, xs′) - end - return m̄, _restructure_pullback -end - -""" - destructure(m) - -Flatten a model's parameters into a single weight vector. - - julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) - - julia> θ, re = destructure(m); - - julia> θ - 67-element Vector{Float32}: - -0.1407104 - ... - -The second return value `re` allows you to reconstruct the original network after making -modifications to the weight vector (for example, with a hypernetwork). 
- - julia> re(θ .* 2) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) -""" -function destructure(m) - xs = Zygote.Buffer([]) - fmap(m) do x - x isa AbstractArray && push!(xs, x) - return x - end - return vcat(vec.(copy(xs))...), p -> _restructure(m, p) -end # Other diff --git a/test/functor.jl b/test/functor.jl new file mode 100644 index 0000000000..acb09df3c8 --- /dev/null +++ b/test/functor.jl @@ -0,0 +1,165 @@ +using Flux: loadparams!, Zeros, destructure + +ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense + +dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout)) + +dm(bias) = Chain( + dl(3, 5, bias), + dl(5, 4, bias), + dl(4, 3, bias) + ) + +nobias(n) = Zeros() + +function testdense(m, bt) + @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) + @test l1.weight == l2.weight + @test l1.bias == l2.bias + @test typeof(l1.bias) === typeof(l2.bias) + end +end + +@testset "Params" begin + m = Dense(10, 5) + @test size.(params(m)) == [(5, 10), (5,)] + m = RNN(10, 5) + @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] + + # Layer duplicated in same chain, params just once pls. + c = Chain(m, m) + @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)] + + # Self-referential array. Just want params, no stack overflow pls. + r = Any[nothing,m] + r[1] = r + @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] + + @testset "use params in gradient context" begin + m = Chain(Dense(3,2), Dense(2,2)) + ps = Flux.params(m) + gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps) + for p in ps + @test gs[p] ≈ ones(size(p)) + end + + w1, w2 = rand(2), rand(2) + ps = Flux.params(w1, w2) + gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps) + for p in ps + @test gs[p] ≈ ones(size(p)) + end + + m = Chain(Dense(3,2), Dense(2,2)) + g = gradient(m -> sum(params(m)[1]), m)[1] + @test g.layers[1].weight == ones(Float32, 2, 3) + + gs = gradient(() -> sum(params(m)[1]), params(m)) + @test gs[params(m)[1]] == ones(Float32, 2, 3) + + # Tests from https://github.com/FluxML/Flux.jl/pull/1614 + m = Dense(3, 2) + ps = Flux.params(m) + data = rand(Float32, 3, 5) + loss(m, x) = sum(m(x).^2) + + g1 = gradient(Flux.params(m)) do + loss(m, data) + end + g2 = gradient(Flux.params(m)) do + ps = Flux.params(m) # just creating params without using them + loss(m, data) + end + g3 = gradient(Flux.params(m)) do + ps = Flux.params(m) + loss(m, data) + sum(sum(p) for p in ps) + end + g4 = gradient(Flux.params(m)) do + loss(m, data) + sum(sum(p) for p in ps) + end + g5 = gradient(Flux.params(m)) do + sum(Flux.params(m)[1]) + sum(Flux.params(m)[2]) + end + g6 = gradient(Flux.params(m)) do + sum(ps[1]) + sum(ps[2]) + end + @test g2[m.weight] == g1[m.weight] + @test g3[m.weight] == g1[m.weight] .+ 1 + @test g4[m.weight] == g1[m.weight] .+ 1 + @test all(g5[m.weight] .== 1) + @test_broken all(g6[m.weight] .== 1) + end +end + + +@testset "Param remapping" begin + @testset "loadparams!" 
begin + pars(w, b) = [w, b] + + pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))] + pars(l) = pars(l.weight, l.bias) + pararray(m) = mapreduce(pars, vcat, m) + weights(m) = mapreduce(l -> [l.weight], vcat, m) + @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) + m = dm(bt) + loadparams!(m, params(m)) + testdense(m, bt) + end + + @testset "$b1 to $b2" for (b1, b2, be) in ( + (Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias + (Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias + (nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change + ) + m1 = dm(b1) + m2 = dm(b2) + loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2)) + testdense(m1, be) + end + end +end + +@testset "Destructure" begin + @testset "Bias type $bt" for bt in (zeros, nobias) + m = dm(bt) + p, re = destructure(m) + testdense(re(p), bt) + end + + @testset "restructure in gradient" begin + x = rand(Float32, 3, 1) + m = dm(zeros) + ∇m = gradient(m -> sum(m(x)), m)[1] + p, re = destructure(m) + ∇p = gradient(θ -> sum(re(θ)(x)), p)[1] + @test ∇p ≈ destructure(∇m)[1] rtol=1e-6 + end + + @testset "destructure with buffers" begin + p, re = destructure(BatchNorm(3)) + @test length(p) == 6 + + # https://github.com/FluxML/Flux.jl/issues/1727 + x = rand(Float32, 3, 4) + y, back = Flux.pullback(x, p) do x, p + vec(re(p)(x)) + end + @test_nowarn back(y) + b = back(y) + @test size(b[1]) == size(x) + @test size(b[2]) == size(p) + end +end + +@testset "Train and test mode" begin + mutable struct DummyLayer + testing::Bool + end + Flux.testmode!(m::DummyLayer, testing=true) = (m.testing = testing; m) + + c = Chain(DummyLayer(true)) + testmode!(c) + @test c[1].testing + trainmode!(c) + @test !c[1].testing +end diff --git a/test/runtests.jl b/test/runtests.jl index 781edb549d..fe75690dac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,10 @@ Random.seed!(0) include("utils.jl") end +@testset "Functor" begin + include("functor.jl") +end + @testset "Onehot" begin include("onehot.jl") end diff --git a/test/utils.jl b/test/utils.jl index 6b487e7854..6b21487f9b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -211,22 +211,6 @@ end end end -@testset "Params" begin - m = Dense(10, 5) - @test size.(params(m)) == [(5, 10), (5,)] - m = RNN(10, 5) - @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] - - # Layer duplicated in same chain, params just once pls. - c = Chain(m, m) - @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)] - - # Self-referential array. Just want params, no stack overflow pls. 
- r = Any[nothing,m] - r[1] = r - @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] -end - @testset "Basic Stacking" begin x = randn(3,3) stacked = stack([x, x], 2) @@ -340,7 +324,6 @@ end @test stack(unstack(stacked_array, 1), 1) == stacked_array end - @testset "Batching" begin stacked_array=[ 8 9 3 5 9 6 6 9 From 0f24e95eeb64f8ee98f982bb1ee1b1b18fee35d9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 15 Jan 2022 21:16:11 +0100 Subject: [PATCH 2/5] docs and simplify tests --- src/functor.jl | 128 +++++++++++++++++++++++++++--------------------- test/functor.jl | 36 +++++++------- 2 files changed, 90 insertions(+), 74 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 59144c4208..101f59b98d 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -66,18 +66,36 @@ end """ destructure(m) -Flatten a model's parameters into a single weight vector. - julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) - julia> θ, re = destructure(m); - julia> θ - 67-element Vector{Float32}: - -0.1407104 - ... -The second return value `re` allows you to reconstruct the original network after making + +Flatten a model's parameters into a single vector. + +```julia-repl +julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) +Chain( + Dense(10, 5, σ), # 55 parameters + Dense(5, 2), # 12 parameters + NNlib.softmax, +) # Total: 4 arrays, 67 parameters, 524 bytes + +julia> θ, re = Flux.destructure(m); + +julia> θ +67-element Vector{Float32}: + -0.1407104 + ... +``` + +The second returned value `re` allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork). - julia> re(θ .* 2) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) + +```julia-repl +julia> re(θ .* 2) +Chain( + Dense(10, 5, σ), # 55 parameters + Dense(5, 2), # 12 parameters + NNlib.softmax, +) # Total: 4 arrays, 67 parameters, 524 bytes. +``` """ function destructure(m) xs = Zygote.Buffer([]) @@ -86,14 +104,23 @@ function destructure(m) end function collect_params!(xs, m) + # Filtering function for the traversal of the functor. + # We walk from node x to children c only if c is one of the trainable children of x. filter = (x, c) -> any(y -> c === y, trainable(x)) + + # Get the walk function corrisponding to the given filter. walk = filtered_walk(filter) + fmap(m; walk) do x x isa AbstractArray{<:Number} && push!(xs, x) return x end end +""" +Return a `walk` function to be passed to `fmap` that applies the function +`f` to be mapped only on the children selected by `filter`. +""" function filtered_walk(filter) seen = IdSet() @@ -112,17 +139,39 @@ function filtered_walk(filter) end + """ - params(m...) + params(m...) + +Collect trainable parameters from the input model(s) `m` into a [`Zygote.Params`](@ref) object. + +Only the parameters that can be reached by recursion on the [`trainable`](@ref) children of the tree with root `m` are collected. +If `trainable` is not defined for a specific node's type in `m`, it will fall back to [`Functor.@functor`](@ref). + +Users are recommended to define `trainable` for their custom types to control the trainable parameters' selection. + +# Examples +```jldoctest +julia> params(Chain(Dense(ones(2,3)), softmax)) # unpacks Flux models +Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]]) -Collect trainable parameters (a.k.a. numerical arrays) -from the input model(s) `m` into a [`Zygote.Params`](@ref) object. 
+julia> bn = BatchNorm(2, relu) +BatchNorm(2, relu) # 4 parameters, plus 4 non-trainable -Only the parameters that can be reached by recursion -on the [`trainable`](@ref) children of -the tree with root `m` are collected. +julia> params(bn) # only the trainable parameters +Params([Float32[0.0, 0.0], Float32[1.0, 1.0]]) -# Usage +julia> params([1, 2, 3], [4]) # one or more arrays of numbers +Params([[1, 2, 3], [4]]) + +julia> params([[1, 2, 3], [4]]) # unpacks array of arrays +Params([[1, 2, 3], [4]]) + +julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples +Params([[2 2], [3, 3, 3]]) +``` + +A `Params` object can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function. ```julia-repl julia> m = Dense(ones(2, 3), zeros(2)) @@ -146,7 +195,11 @@ julia> gs[m.weight] 2.0 2.0 2.0 ``` """ -function params end +function params(m...) + ps = Params() + params!(ps, m) + return ps +end ## TODO This causes some test regressions. Why? # function params(m...) @@ -165,43 +218,6 @@ function params!(p::Params, x, seen = IdSet()) end end -""" - params(model) - params(layers...) - -Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters. - -This can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function. - -The behaviour of `params` on custom types can be customized using [`Functor.@functor`](@ref) or [`Flux.trainable`](@ref). - -# Examples -```jldoctest -julia> params(Chain(Dense(ones(2,3)), softmax)) # unpacks Flux models -Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]]) - -julia> bn = BatchNorm(2, relu) -BatchNorm(2, relu) # 4 parameters, plus 4 non-trainable - -julia> params(bn) # only the trainable parameters -Params([Float32[0.0, 0.0], Float32[1.0, 1.0]]) - -julia> params([1, 2, 3], [4]) # one or more arrays of numbers -Params([[1, 2, 3], [4]]) - -julia> params([[1, 2, 3], [4]]) # unpacks array of arrays -Params([[1, 2, 3], [4]]) - -julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples -Params([[2 2], [3, 3, 3]]) -``` -""" -function params(m...) - ps = Params() - params!(ps, m) - return ps -end - function loadparams!(m, xs) for (p, x) in zip(params(m), xs) size(p) == size(x) || diff --git a/test/functor.jl b/test/functor.jl index acb09df3c8..a5a4d0e21c 100644 --- a/test/functor.jl +++ b/test/functor.jl @@ -1,19 +1,19 @@ using Flux: loadparams!, Zeros, destructure -ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense +function build_test_chain(fbias) + ls(dims...) = reshape(collect(Float32, 1:prod(dims)), reverse(dims)...) 
-dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout)) - -dm(bias) = Chain( - dl(3, 5, bias), - dl(5, 4, bias), - dl(4, 3, bias) - ) + Chain( + Dense(ls(3, 5), fbias(5)), + Dense(ls(5, 4), fbias(4)), + Dense(ls(4, 3), fbias(3)) + ) +end nobias(n) = Zeros() -function testdense(m, bt) - @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) +function test_chains_equal(m1, m2) + @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m1, m2)) @test l1.weight == l2.weight @test l1.bias == l2.bias @test typeof(l1.bias) === typeof(l2.bias) @@ -101,9 +101,9 @@ end pararray(m) = mapreduce(pars, vcat, m) weights(m) = mapreduce(l -> [l.weight], vcat, m) @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) - m = dm(bt) + m = build_test_chain(bt) loadparams!(m, params(m)) - testdense(m, bt) + test_chains_equal(m, build_test_chain(bt)) end @testset "$b1 to $b2" for (b1, b2, be) in ( @@ -111,24 +111,24 @@ end (Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias (nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change ) - m1 = dm(b1) - m2 = dm(b2) + m1 = build_test_chain(b1) + m2 = build_test_chain(b2) loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2)) - testdense(m1, be) + test_chains_equal(m1, build_test_chain(be)) end end end @testset "Destructure" begin @testset "Bias type $bt" for bt in (zeros, nobias) - m = dm(bt) + m = build_test_chain(bt) p, re = destructure(m) - testdense(re(p), bt) + test_chains_equal(re(p), build_test_chain(bt)) end @testset "restructure in gradient" begin x = rand(Float32, 3, 1) - m = dm(zeros) + m = build_test_chain(zeros) ∇m = gradient(m -> sum(m(x)), m)[1] p, re = destructure(m) ∇p = gradient(θ -> sum(re(θ)(x)), p)[1] From 3a8eed4010689c2b633c89ad4e9a4cd6b0db31ad Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 15 Jan 2022 21:47:07 +0100 Subject: [PATCH 3/5] address review comments --- src/functor.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 101f59b98d..06196b4bbc 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -42,8 +42,8 @@ trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, function _restructure(m, xs) i = 0 - filter = (x, c) -> any(y -> c === y, trainable(x)) - walk = filtered_walk(filter) + cond = (x, c) -> any(y -> c === y, trainable(x)) + walk = filtered_walk(cond) m̄ = fmap(m; walk) do x x isa AbstractArray{<:Number} || return x x = reshape(xs[i .+ (1:length(x))], size(x)) @@ -106,10 +106,10 @@ end function collect_params!(xs, m) # Filtering function for the traversal of the functor. # We walk from node x to children c only if c is one of the trainable children of x. - filter = (x, c) -> any(y -> c === y, trainable(x)) + cond = (x, c) -> any(y -> c === y, trainable(x)) - # Get the walk function corrisponding to the given filter. - walk = filtered_walk(filter) + # Get the walk function corrisponding to the given condition. + walk = filtered_walk(cond) fmap(m; walk) do x x isa AbstractArray{<:Number} && push!(xs, x) @@ -121,7 +121,7 @@ end Return a `walk` function to be passed to `fmap` that applies the function `f` to be mapped only on the children selected by `filter`. 
""" -function filtered_walk(filter) +function filtered_walk(cond::Function) seen = IdSet() function walk(f, x) From deed8052ae07ed566dc94dfda0c30f33042db0ea Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 15 Jan 2022 21:56:37 +0100 Subject: [PATCH 4/5] destructure docstring --- src/functor.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 06196b4bbc..3559045b06 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -67,21 +67,24 @@ end """ destructure(m) -Flatten a model's parameters into a single vector. +Flatten a model's trainable parameters into a single vector. ```julia-repl -julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) +julia> m = Chain(Dense(10, 5, relu), BatchNorm(5), Dense(5, 2), softmax) Chain( - Dense(10, 5, σ), # 55 parameters + Dense(10, 5, relu), # 55 parameters + BatchNorm(5), # 10 parameters, plus 10 Dense(5, 2), # 12 parameters NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes +) # Total: 6 trainable arrays, 77 parameters, + # plus 2 non-trainable, 10 parameters, summarysize 836 bytes. julia> θ, re = Flux.destructure(m); julia> θ -67-element Vector{Float32}: - -0.1407104 +77-element Vector{Float32}: + -0.23847938 + -0.3672024 ... ``` @@ -89,12 +92,14 @@ The second returned value `re` allows you to reconstruct the original network af modifications to the weight vector (for example, with a hypernetwork). ```julia-repl -julia> re(θ .* 2) +julia> re(θ) Chain( - Dense(10, 5, σ), # 55 parameters + Dense(10, 5, relu), # 55 parameters + BatchNorm(5), # 10 parameters, plus 10 Dense(5, 2), # 12 parameters NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. +) # Total: 6 trainable arrays, 77 parameters, + # plus 2 non-trainable, 10 parameters, summarysize 836 bytes. ``` """ function destructure(m) @@ -130,7 +135,7 @@ function filtered_walk(cond::Function) children, reconstruct = functor(x) mappedchildren = map(children) do c - filter(x, c) ? f(c) : c + cond(x, c) ? f(c) : c end reconstruct(mappedchildren) end From f179c4d2f8992b3aef8c2b3a3568c20fd6239df4 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 15 Jan 2022 22:05:01 +0100 Subject: [PATCH 5/5] more docs improve --- src/functor.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/functor.jl b/src/functor.jl index 3559045b06..ecbb3c04f7 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -101,6 +101,9 @@ Chain( ) # Total: 6 trainable arrays, 77 parameters, # plus 2 non-trainable, 10 parameters, summarysize 836 bytes. ``` + +Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers) +it will be collected only once. """ function destructure(m) xs = Zygote.Buffer([])