Have destructure return only trainable params #1742
@@ -1,7 +1,6 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

trainable(m) = functor(m)[1]
@@ -38,25 +37,126 @@ Possible values include:
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)

# Flattening models to weight vectors, and back

function _restructure(m, xs)
  i = 0
  cond = (x, c) -> any(y -> c === y, trainable(x))
  walk = filtered_walk(cond)
  m̄ = fmap(m; walk) do x
    x isa AbstractArray{<:Number} || return x
    x = reshape(xs[i .+ (1:length(x))], size(x))
    i += length(x)
    return x
  end
  length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
  return m̄
end
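For intuition, here is a minimal standalone sketch, in plain Julia with made-up arrays (not Flux's actual code path), of the running-index-plus-reshape pattern that `_restructure` uses:

```julia
# Rebuild a tuple of arrays from a flat vector, mirroring the pattern above.
arrays = (randn(2, 3), randn(4))
flat = vcat(vec.(collect(arrays))...)      # 10-element flat vector

rebuilt = let i = 0
    map(arrays) do x
        y = reshape(flat[i .+ (1:length(x))], size(x))
        i += length(x)
        y
    end
end
@assert all(rebuilt .== arrays)
```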

@adjoint function _restructure(m, xs)
  m̄, numel = _restructure(m, xs), length(xs)
  function _restructure_pullback(dm)
    xs′ = destructure(dm)[1]
    numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
    return (nothing, xs′)
  end
  return m̄, _restructure_pullback
end
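To see what this adjoint is for, a hedged sketch of differentiating through the reconstruction function returned by `destructure` (the model and sizes here are arbitrary choices, not taken from this PR):

```julia
using Flux

m = Dense(3, 2)
θ, re = Flux.destructure(m)
x = rand(Float32, 3)

# The pullback above flattens the gradient of the rebuilt model back into a
# vector with the same layout as θ.
g = Flux.gradient(p -> sum(re(p)(x)), θ)[1]
length(g) == length(θ)    # one gradient entry per flattened parameter
```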

"""
    destructure(m)

Flatten a model's trainable parameters into a single vector.

```julia-repl
julia> m = Chain(Dense(10, 5, relu), BatchNorm(5), Dense(5, 2), softmax)
Chain(
  Dense(10, 5, relu),  # 55 parameters
  BatchNorm(5),        # 10 parameters, plus 10
  Dense(5, 2),         # 12 parameters
  NNlib.softmax,
)  # Total: 6 trainable arrays, 77 parameters,
   # plus 2 non-trainable, 10 parameters, summarysize 836 bytes.

julia> θ, re = Flux.destructure(m);

julia> θ
77-element Vector{Float32}:
 -0.23847938
 -0.3672024
 ...
```

The second returned value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).

```julia-repl
julia> re(θ)
Chain(
  Dense(10, 5, relu),  # 55 parameters
  BatchNorm(5),        # 10 parameters, plus 10
  Dense(5, 2),         # 12 parameters
  NNlib.softmax,
)  # Total: 6 trainable arrays, 77 parameters,
   # plus 2 non-trainable, 10 parameters, summarysize 836 bytes.
```

Only numerical arrays are collected by `destructure`. Moreover, if the same array is nested
multiple times in the same model (e.g. shared by some layers), it will be collected only once.
"""
function destructure(m)
  xs = Zygote.Buffer([])
  collect_params!(xs, m)
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

Member
Why does this copy? (And splat?) And, how easy would it be to avoid Buffer somehow, to make this ready for not using Zygote?

Member (Author)
With

    function destructure(m)
      xs = AbstractArray[]
      collect_params!(xs, m)
      return vcat(vec.(xs)...), p -> _restructure(m, p)
    end

Flux's tests still pass. I still have to test the interaction with DiffEqFlux, NeuralPDE, let's see.

Member (Author)
DiffEqFlux tests pass with both. @ChrisRackauckas, do you see any particular reason to keep Buffer?

Member
I don't see a reason to use Buffer here.

Member
I guess that's for some potential higher order AD issue?

Member
Or just first-order AD. AFAICT Flux's current test suite never tests the gradient of

Member
🙈 is the Flux motto, really.
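As a quick, hedged check of the docstring's claim that a shared array is collected only once (the layer sharing below is an invented example, and the field names assume the current `Dense` layout):

```julia
using Flux

d = Dense(3, 3)
m = Chain(d, d)    # the same layer, and hence the same weight and bias, appears twice

θ, _ = Flux.destructure(m)
length(θ) == length(d.weight) + length(d.bias)    # true: 12 entries, not 24
```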

function collect_params!(xs, m)
  # Filtering function for the traversal of the functor.
  # We walk from node x to children c only if c is one of the trainable children of x.
  cond = (x, c) -> any(y -> c === y, trainable(x))

  # Get the walk function corresponding to the given condition.
  walk = filtered_walk(cond)

  fmap(m; walk) do x
    x isa AbstractArray{<:Number} && push!(xs, x)
    return x
  end
end
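A hedged sketch of what `collect_params!` gathers; it is an internal helper introduced by this PR, so the qualified call below is an assumption about where it ends up living:

```julia
using Flux

m = Chain(Dense(3, 4, relu), Dense(4, 2))

xs = AbstractArray[]
Flux.collect_params!(xs, m)    # assumes the helper lands under Flux as in this PR
length(xs)                      # 4: two weight matrices and two bias vectors
```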

"""
Return a `walk` function to be passed to `fmap` that applies the mapped function `f`
only to the children selected by `cond`.
"""
function filtered_walk(cond::Function)
  seen = IdSet()

  function walk(f, x)
    x in seen && return x
    push!(seen, x)

    children, reconstruct = functor(x)
    mappedchildren = map(children) do c
      cond(x, c) ? f(c) : c
    end
    reconstruct(mappedchildren)
  end

  return walk
end
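A hedged sketch of using `filtered_walk` directly, mirroring `collect_params!` above; it assumes the helper is reachable as `Flux.filtered_walk` and a Functors version whose `fmap` accepts a `walk` keyword:

```julia
using Flux
using Functors: fmap

m = Chain(Dense(2, 2), BatchNorm(2))

cond = (x, c) -> any(y -> c === y, Flux.trainable(x))
walk = Flux.filtered_walk(cond)

n = Ref(0)
fmap(m; walk) do x
    x isa AbstractArray{<:Number} && (n[] += 1)   # count arrays actually visited
    x
end
n[]    # 4: BatchNorm's running mean and variance are not reached
```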

"""
    params(m...)

Collect trainable parameters from the input model(s) `m` into a [`Zygote.Params`](@ref) object.

Only the parameters that can be reached by recursion on the [`trainable`](@ref) children of the tree with root `m` are collected.
If `trainable` is not defined for a node's type in `m`, collection falls back to the children given by [`Functors.@functor`](@ref).

The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref).
Defining `trainable` for a custom type is the recommended way to control which of its parameters are collected.

# Examples
```jldoctest
@@ -78,13 +178,54 @@ Params([[1, 2, 3], [4]])
julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples
Params([[2 2], [3, 3, 3]])
```

A `Params` object can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.

```julia-repl
julia> m = Dense(ones(2, 3), zeros(2))
Dense(3, 2)  # 8 parameters

julia> ps = Flux.params(m)
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> x = ones(3)
3-element Vector{Float64}:
 1.0
 1.0
 1.0

julia> gs = gradient(() -> sum(2 .* m(x)), ps)
Grads(...)

julia> gs[m.weight]
2×3 Matrix{Float64}:
 2.0  2.0  2.0
 2.0  2.0  2.0
```
"""
function params(m...)
  ps = Params()
  params!(ps, m)
  return ps
end
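As the docstring recommends, parameter collection for a custom type is controlled with `@functor` and `Flux.trainable`; a sketch with a made-up `Affine` struct:

```julia
using Flux
using Flux: @functor

struct Affine
    W
    b
end
@functor Affine

# Collect only W; b stays out of params (and hence out of training).
Flux.trainable(a::Affine) = (a.W,)

a = Affine(rand(2, 2), rand(2))
Flux.params(a)    # Params containing only the 2×2 matrix W
```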

## TODO This causes some test regressions. Why?
# function params(m...)
#   ps = Params()
#   collect_params!(ps, m)
#   return ps
# end

params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
  x in seen && return
  push!(seen, x)
  for child in trainable(x)
    params!(p, child, seen)
  end
end

function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||

Member
This gradient can easily be wrong: it looks for duplicates in the gradient, which can come from e.g. adding two parameters `x + y`. It is completely unaware of duplicates in the original. Demonstration: #1826 (comment)

So at the very least, we must (1) disable the removal of duplicates from the destructure used here, and (2) throw an error if you try to use this adjoint when the original model had any duplicates.

Or, failing that, we should remove it from v0.13 until someone can write a version which actually works.
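A hedged sketch of the failure mode described above, following the demonstration linked in #1826: tie a parameter by reusing a layer, then differentiate through `re`. The sketch only sets up the scenario; the resulting flat gradient should be checked against, say, a finite-difference estimate rather than trusted.

```julia
using Flux

d = Dense(2, 2)
m = Chain(d, d)             # d.weight and d.bias are shared between both positions

θ, re = Flux.destructure(m) # shared arrays are collected once, so θ has 6 entries
x = rand(Float32, 2)

g = Flux.gradient(p -> sum(re(p)(x)), θ)[1]
# With tied parameters, this flat gradient is not guaranteed to match the true
# derivative of the loss with respect to θ; see the linked issue.
```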