1 change: 1 addition & 0 deletions src/Flux.jl
@@ -8,6 +8,7 @@ using Zygote, MacroTools, Juno, Reexport
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
using Functors: Functors, @functor, functor, fmap
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
165 changes: 153 additions & 12 deletions src/functor.jl
@@ -1,7 +1,6 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

trainable(m) = functor(m)[1]
@@ -38,25 +37,126 @@ Possible values include:
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)

params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
# Flattening models to weight vectors, and back

function _restructure(m, xs)
i = 0
cond = (x, c) -> any(y -> c === y, trainable(x))
walk = filtered_walk(cond)
m̄ = fmap(m; walk) do x
x isa AbstractArray{<:Number} || return x
x = reshape(xs[i .+ (1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end
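The indexing pattern in `_restructure` (take `length(x)` entries at offset `i`, reshape to `size(x)`) can be sketched in plain Julia, without Flux, on a tuple of stand-in weight arrays:

```julia
# A minimal sketch (plain Julia, no Flux) of the flatten/rebuild pattern used
# by `_restructure`: walk the arrays in a fixed order, slicing `length(x)`
# entries at offset `i` and reshaping them back to `size(x)`.
function rebuild(arrays, xs)
    i = 0
    out = map(arrays) do x
        y = reshape(xs[i .+ (1:length(x))], size(x))
        i += length(x)
        y
    end
    i == length(xs) || @warn "Expected $i params, got $(length(xs))"
    out
end

arrays = ([1.0 2.0; 3.0 4.0], [5.0, 6.0])   # stand-ins for weight arrays
flat = vcat(vec.(arrays)...)                # the "destructured" vector (column-major)
rebuilt = rebuild(arrays, 10 .* flat)       # rebuild from a modified vector
```

Note that `vec` flattens column-major, so round-tripping relies on `reshape` using the same order.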

@adjoint function _restructure(m, xs)
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
Member:

This gradient can easily be wrong: it looks for duplicates in the gradient, which can come from e.g. adding two parameters x + y, but it is completely unaware of duplicates in the original.

Demonstration: #1826 (comment)

So at the very least, we must (1) disable the removal of duplicates from the destructure used here, and (2) throw an error if you try to use this adjoint when the original model had any duplicated parameters.

Or, failing that, we should remove it from v0.13 until someone can write a version which actually works.
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end
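The reviewer's concern about duplicates can be illustrated with identity-based de-duplication alone. A minimal sketch in plain Julia (no Flux, hypothetical helper name `collect_arrays`): a shared array is collected once from the model, but a pullback typically returns distinct gradient arrays, so the two flattened lengths disagree.

```julia
# Hypothetical sketch of the shared-parameter pitfall (plain Julia, no Flux):
# `collect_arrays` de-duplicates by object identity, as the collection of
# trainable parameters does.
function collect_arrays(m)
    seen = IdDict{Any,Nothing}()
    out = Vector{Array{Float64}}()
    for x in m
        haskey(seen, x) && continue   # skip arrays already seen (===)
        seen[x] = nothing
        push!(out, x)
    end
    out
end

w = [1.0, 2.0]
model = (w, w)                      # the same array, shared twice
grads = ([3.0, 4.0], [5.0, 6.0])    # a pullback returns distinct arrays

length(collect_arrays(model))       # the shared array is seen once
length(collect_arrays(grads))       # the gradients are not identical objects
```

The mismatch between the two lengths is exactly the failure mode the comment describes.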

"""
destructure(m)

Flatten a model's trainable parameters into a single vector.

```julia-repl
julia> m = Chain(Dense(10, 5, relu), BatchNorm(5), Dense(5, 2), softmax)
Chain(
Dense(10, 5, relu), # 55 parameters
BatchNorm(5), # 10 parameters, plus 10
Dense(5, 2), # 12 parameters
NNlib.softmax,
) # Total: 6 trainable arrays, 77 parameters,
# plus 2 non-trainable, 10 parameters, summarysize 836 bytes.

julia> θ, re = Flux.destructure(m);

julia> θ
77-element Vector{Float32}:
-0.23847938
-0.3672024
...
```

The second returned value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).

```julia-repl
julia> re(θ)
Chain(
Dense(10, 5, relu), # 55 parameters
BatchNorm(5), # 10 parameters, plus 10
Dense(5, 2), # 12 parameters
NNlib.softmax,
) # Total: 6 trainable arrays, 77 parameters,
# plus 2 non-trainable, 10 parameters, summarysize 836 bytes.
```

Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers)

Member:

Suggested change:
- Only numerical arrays are collected by `destructe`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers)
+ Only numerical arrays are collected by `destructure`. Moreover, if the same array is nested multiple times in the same model (e.g. shared by some layers),

it will be collected only once.
"""
function destructure(m)
xs = Zygote.Buffer([])
collect_params!(xs, m)
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
Member:

Why does this copy? (And splat?)

And, how easy would it be to avoid Buffer somehow, to make this ready for not using Zygote?

Member Author:

With

function destructure(m)
  xs = AbstractArray[]
  collect_params!(xs, m)
  return vcat(vec.(xs)...), p -> _restructure(m, p)
end

Flux's tests still pass. I still have to test the interaction with DiffEqFlux, NeuralPDE, let's see.

Member Author:

DiffEqFlux tests pass with both Buffer() and AbstractArray[].

@ChrisRackauckas you see any particular reason to keep Buffer?

Member:

I don't see a reason to use Buffer here.

Member:

I guess that's for some potential higher order AD issue?

Member:

Or just first-order AD. AFAICT Flux's current test suite never tests the gradient of destructure (only restructuring) 🙈...

Member:

🙈 is the Flux motto, really.

end

function collect_params!(xs, m)
# Filtering function for the traversal of the functor.
# We walk from node x to children c only if c is one of the trainable children of x.
cond = (x, c) -> any(y -> c === y, trainable(x))

# Get the walk function corresponding to the given condition.
walk = filtered_walk(cond)

fmap(m; walk) do x
x isa AbstractArray{<:Number} && push!(xs, x)
return x
end
end

"""
Return a `walk` function to be passed to `fmap` that applies the function
`f` to be mapped only on the children selected by `filter`.
"""
params(model)
params(layers...)
function filtered_walk(cond::Function)
seen = IdSet()

Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters.
function walk(f, x)
x in seen && return x
push!(seen, x)

This can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.
children, reconstruct = functor(x)
mappedchildren = map(children) do c
cond(x, c) ? f(c) : c
end
reconstruct(mappedchildren)
end

return walk
end
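The `seen` set is what makes the walk terminate on shared sub-structures: each node is visited at most once, by object identity. A stand-alone sketch of that guard in plain Julia (using an `IdDict` as the identity-keyed seen set, rather than Zygote's `IdSet`):

```julia
# A minimal sketch (plain Julia) of the seen-set guard used by the filtered
# walk: each node is visited at most once, by object identity, so shared
# sub-structures are neither double-counted nor looped over forever.
function walk_once(f, x, seen = IdDict{Any,Nothing}())
    haskey(seen, x) && return       # already visited this exact object
    seen[x] = nothing
    f(x)
    x isa Union{Tuple,NamedTuple} || return   # only recurse into containers
    for c in x
        walk_once(f, c, seen)
    end
    return
end

shared = (a = 1,)
tree = (shared, shared)             # the same node appears twice

visits = Any[]
walk_once(x -> push!(visits, x), tree)
```

Here `visits` contains the tree, the shared node once, and its leaf: the second occurrence of `shared` is skipped by the guard.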



"""
params(m...)

Collect trainable parameters from the input model(s) `m` into a [`Zygote.Params`](@ref) object.

Only the parameters that can be reached by recursion on the [`trainable`](@ref) children of the tree with root `m` are collected.
If `trainable` is not defined for a specific node's type in `m`, it falls back to [`Functors.@functor`](@ref).

The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref).
Defining `trainable` for a custom type is the recommended way to control which of its parameters are collected.

# Examples
```jldoctest
@@ -78,13 +178,54 @@ Params([[1, 2, 3], [4]])
julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples
Params([[2 2], [3, 3, 3]])
```

A `Params` object can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.

```julia-repl
julia> m = Dense(ones(2, 3), zeros(2))
Dense(3, 2) # 8 parameters

julia> ps = Flux.params(m)
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> x = ones(3)
3-element Vector{Float64}:
1.0
1.0
1.0

julia> gs = gradient(() -> sum(2 .* m(x)), ps)
Grads(...)

julia> gs[m.weight]
2×3 Matrix{Float64}:
2.0 2.0 2.0
2.0 2.0 2.0
```
"""
function params(m...)
ps = Params()
params!(ps, m)
return ps
end

## TODO This causes some test regressions. Why?
# function params(m...)
# ps = Params()
# collect_params!(ps, m)
# return ps
# end

params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
end
end
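The `params!` recursion above can be mimicked in plain Julia (no Flux/Zygote), with a hypothetical `children` function standing in for `trainable`: numeric-array leaves are collected, every other node recurses into its children, and an identity-keyed seen set guards against revisits.

```julia
# A toy version (plain Julia, no Flux/Zygote) of the params! recursion:
# numeric arrays are collected as leaves; tuples and named tuples recurse.
children(x::Union{Tuple,NamedTuple}) = values(x)
children(x) = ()                    # everything else is a leaf with no children

gather!(ps, x::AbstractArray{<:Number}, seen = IdDict()) = push!(ps, x)

function gather!(ps, x, seen = IdDict())
    haskey(seen, x) && return       # already visited this node
    seen[x] = nothing
    for c in children(x)
        gather!(ps, c, seen)
    end
end

model = (layer1 = (W = [1.0 2.0], b = [0.0]), act = sin)
ps = Any[]
gather!(ps, model)                  # collects W and b; sin has no children
```

Dispatch does the leaf test here: the `AbstractArray{<:Number}` method plays the role of the `params!` method for numeric arrays.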

function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
2 changes: 1 addition & 1 deletion src/layers/basic.jl
@@ -41,7 +41,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
Functors.functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
4 changes: 2 additions & 2 deletions src/layers/show.jl
@@ -43,7 +43,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end

_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
_show_leaflike(::Diagonal) = true # appears inside LayerNorm
@@ -97,7 +97,7 @@ function _big_finale(io::IO, m)
end

_childarray_sum(f, x::AbstractArray) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
_childarray_sum(f, x) = Functors.isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))

# utility functions

53 changes: 0 additions & 53 deletions src/utils.jl
@@ -629,59 +629,6 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end

# Flattening models to weight vectors, and back

function _restructure(m, xs)
i = 0
m̄ = fmap(m) do x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end

@adjoint function _restructure(m, xs)
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end

"""
destructure(m)

Flatten a model's parameters into a single weight vector.

julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

julia> θ
67-element Vector{Float32}:
-0.1407104
...

The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).

julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
fmap(m) do x
x isa AbstractArray && push!(xs, x)
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

# Other
