Skip to content

Commit 4dc70b5

Browse files
destructure returns only trainable params
1 parent 79dbbd6 commit 4dc70b5

File tree

8 files changed

+291
-74
lines changed

8 files changed

+291
-74
lines changed

src/Flux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Zygote, MacroTools, Juno, Reexport
88
using MacroTools: @forward
99
@reexport using NNlib
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
11+
using Functors: Functors, @functor, functor, fmap
1112
export gradient
1213

1314
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,

src/functor.jl

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import Adapt: adapt, adapt_storage
22
using LinearAlgebra: Cholesky
33
using Zygote: IdSet
4-
import Functors: Functors, @functor, functor, fmap, isleaf
54
using SparseArrays: AbstractSparseArray
65

76
trainable(m) = functor(m)[1]
@@ -38,6 +37,124 @@ Possible values include:
3837
"""
3938
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
4039

40+
41+
# Flattening models to weight vectors, and back
42+
43+
function _restructure(m, xs)
44+
i = 0
45+
filter = (x, c) -> any(y -> c === y, trainable(x))
46+
walk = filtered_walk(filter)
47+
m̄ = fmap(m; walk) do x
48+
x isa AbstractArray{<:Number} || return x
49+
x = reshape(xs[i .+ (1:length(x))], size(x))
50+
i += length(x)
51+
return x
52+
end
53+
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
54+
return m̄
55+
end
56+
57+
@adjoint function _restructure(m, xs)
58+
m̄, numel = _restructure(m, xs), length(xs)
59+
function _restructure_pullback(dm)
60+
xs′ = destructure(dm)[1]
61+
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
62+
return (nothing, xs′)
63+
end
64+
return m̄, _restructure_pullback
65+
end
66+
67+
"""
68+
destructure(m)
69+
Flatten a model's parameters into a single weight vector.
70+
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
71+
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
72+
julia> θ, re = destructure(m);
73+
julia> θ
74+
67-element Vector{Float32}:
75+
-0.1407104
76+
...
77+
The second return value `re` allows you to reconstruct the original network after making
78+
modifications to the weight vector (for example, with a hypernetwork).
79+
julia> re(θ .* 2)
80+
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
81+
"""
82+
function destructure(m)
83+
xs = Zygote.Buffer([])
84+
collect_params!(xs, m)
85+
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
86+
end
87+
88+
function collect_params!(xs, m)
89+
filter = (x, c) -> any(y -> c === y, trainable(x))
90+
walk = filtered_walk(filter)
91+
fmap(m; walk) do x
92+
x isa AbstractArray{<:Number} && push!(xs, x)
93+
return x
94+
end
95+
end
96+
97+
function filtered_walk(filter)
98+
seen = IdSet()
99+
100+
function walk(f, x)
101+
x in seen && return x
102+
push!(seen, x)
103+
104+
children, reconstruct = functor(x)
105+
mappedchildren = map(children) do c
106+
filter(x, c) ? f(c) : c
107+
end
108+
reconstruct(mappedchildren)
109+
end
110+
111+
return walk
112+
end
113+
114+
115+
"""
116+
params(m...)
117+
118+
Collect trainable parameters (a.k.a. numerical arrays)
119+
from the input model(s) `m` into a [`Zygote.Params`](@ref) object.
120+
121+
Only the parameters that can be reached by recursion
122+
on the [`trainable`](@ref) children of
123+
the tree with root `m` are collected.
124+
125+
# Usage
126+
127+
```julia-repl
128+
julia> m = Dense(ones(2, 3), zeros(2))
129+
Dense(3, 2) # 8 parameters
130+
131+
julia> ps = Flux.params(m)
132+
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
133+
134+
julia> x = ones(3)
135+
3-element Vector{Float64}:
136+
1.0
137+
1.0
138+
1.0
139+
140+
julia> gs = gradient(() -> sum(2 .* m(x)), ps)
141+
Grads(...)
142+
143+
julia> gs[m.weight]
144+
2×3 Matrix{Float64}:
145+
2.0 2.0 2.0
146+
2.0 2.0 2.0
147+
```
148+
"""
149+
function params end
150+
151+
## TODO This causes some test regressions. Why?
152+
# function params(m...)
153+
# ps = Params()
154+
# collect_params!(ps, m)
155+
# return ps
156+
# end
157+
41158
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
42159

43160
function params!(p::Params, x, seen = IdSet())

src/layers/basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
4242
Base.iterate, Base.lastindex, Base.keys
4343

44-
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
44+
Functors.functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
4545

4646
applychain(::Tuple{}, x) = x
4747
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

src/layers/show.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
4343
end
4444
end
4545

46-
_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
46+
_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
4747
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
4848
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
4949
_show_leaflike(::Diagonal) = true # appears inside LayerNorm
@@ -97,7 +97,7 @@ function _big_finale(io::IO, m)
9797
end
9898

9999
_childarray_sum(f, x::AbstractArray) = f(x)
100-
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
100+
_childarray_sum(f, x) = Functors.isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
101101

102102
# utility functions
103103

src/utils.jl

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -629,59 +629,6 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
629629
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
630630
end
631631

632-
# Flattening models to weight vectors, and back
633-
634-
function _restructure(m, xs)
635-
i = 0
636-
m̄ = fmap(m) do x
637-
x isa AbstractArray || return x
638-
x = reshape(xs[i.+(1:length(x))], size(x))
639-
i += length(x)
640-
return x
641-
end
642-
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
643-
return m̄
644-
end
645-
646-
@adjoint function _restructure(m, xs)
647-
m̄, numel = _restructure(m, xs), length(xs)
648-
function _restructure_pullback(dm)
649-
xs′ = destructure(dm)[1]
650-
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
651-
return (nothing, xs′)
652-
end
653-
return m̄, _restructure_pullback
654-
end
655-
656-
"""
657-
destructure(m)
658-
659-
Flatten a model's parameters into a single weight vector.
660-
661-
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
662-
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
663-
664-
julia> θ, re = destructure(m);
665-
666-
julia> θ
667-
67-element Vector{Float32}:
668-
-0.1407104
669-
...
670-
671-
The second return value `re` allows you to reconstruct the original network after making
672-
modifications to the weight vector (for example, with a hypernetwork).
673-
674-
julia> re(θ .* 2)
675-
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
676-
"""
677-
function destructure(m)
678-
xs = Zygote.Buffer([])
679-
fmap(m) do x
680-
x isa AbstractArray && push!(xs, x)
681-
return x
682-
end
683-
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
684-
end
685632

686633
# Other
687634

test/functor.jl

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
using Flux: loadparams!, Zeros, destructure
2+
3+
ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense
4+
5+
dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))
6+
7+
dm(bias) = Chain(
8+
dl(3, 5, bias),
9+
dl(5, 4, bias),
10+
dl(4, 3, bias)
11+
)
12+
13+
nobias(n) = Zeros()
14+
15+
function testdense(m, bt)
16+
@testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
17+
@test l1.weight == l2.weight
18+
@test l1.bias == l2.bias
19+
@test typeof(l1.bias) === typeof(l2.bias)
20+
end
21+
end
22+
23+
@testset "Params" begin
24+
m = Dense(10, 5)
25+
@test size.(params(m)) == [(5, 10), (5,)]
26+
m = RNN(10, 5)
27+
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]
28+
29+
# Layer duplicated in same chain, params just once pls.
30+
c = Chain(m, m)
31+
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]
32+
33+
# Self-referential array. Just want params, no stack overflow pls.
34+
r = Any[nothing,m]
35+
r[1] = r
36+
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
37+
38+
@testset "use params in gradient context" begin
39+
m = Chain(Dense(3,2), Dense(2,2))
40+
ps = Flux.params(m)
41+
gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
42+
for p in ps
43+
@test gs[p] ≈ ones(size(p))
44+
end
45+
46+
w1, w2 = rand(2), rand(2)
47+
ps = Flux.params(w1, w2)
48+
gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
49+
for p in ps
50+
@test gs[p] ≈ ones(size(p))
51+
end
52+
53+
m = Chain(Dense(3,2), Dense(2,2))
54+
g = gradient(m -> sum(params(m)[1]), m)[1]
55+
@test g.layers[1].weight == ones(Float32, 2, 3)
56+
57+
gs = gradient(() -> sum(params(m)[1]), params(m))
58+
@test gs[params(m)[1]] == ones(Float32, 2, 3)
59+
60+
# Tests from https://github.com/FluxML/Flux.jl/pull/1614
61+
m = Dense(3, 2)
62+
ps = Flux.params(m)
63+
data = rand(Float32, 3, 5)
64+
loss(m, x) = sum(m(x).^2)
65+
66+
g1 = gradient(Flux.params(m)) do
67+
loss(m, data)
68+
end
69+
g2 = gradient(Flux.params(m)) do
70+
ps = Flux.params(m) # just creating params without using them
71+
loss(m, data)
72+
end
73+
g3 = gradient(Flux.params(m)) do
74+
ps = Flux.params(m)
75+
loss(m, data) + sum(sum(p) for p in ps)
76+
end
77+
g4 = gradient(Flux.params(m)) do
78+
loss(m, data) + sum(sum(p) for p in ps)
79+
end
80+
g5 = gradient(Flux.params(m)) do
81+
sum(Flux.params(m)[1]) + sum(Flux.params(m)[2])
82+
end
83+
g6 = gradient(Flux.params(m)) do
84+
sum(ps[1]) + sum(ps[2])
85+
end
86+
@test g2[m.weight] == g1[m.weight]
87+
@test g3[m.weight] == g1[m.weight] .+ 1
88+
@test g4[m.weight] == g1[m.weight] .+ 1
89+
@test all(g5[m.weight] .== 1)
90+
@test_broken all(g6[m.weight] .== 1)
91+
end
92+
end
93+
94+
95+
@testset "Param remapping" begin
96+
@testset "loadparams!" begin
97+
pars(w, b) = [w, b]
98+
99+
pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))]
100+
pars(l) = pars(l.weight, l.bias)
101+
pararray(m) = mapreduce(pars, vcat, m)
102+
weights(m) = mapreduce(l -> [l.weight], vcat, m)
103+
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
104+
m = dm(bt)
105+
loadparams!(m, params(m))
106+
testdense(m, bt)
107+
end
108+
109+
@testset "$b1 to $b2" for (b1, b2, be) in (
110+
(Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
111+
(Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
112+
(nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
113+
)
114+
m1 = dm(b1)
115+
m2 = dm(b2)
116+
loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
117+
testdense(m1, be)
118+
end
119+
end
120+
end
121+
122+
@testset "Destructure" begin
123+
@testset "Bias type $bt" for bt in (zeros, nobias)
124+
m = dm(bt)
125+
p, re = destructure(m)
126+
testdense(re(p), bt)
127+
end
128+
129+
@testset "restructure in gradient" begin
130+
x = rand(Float32, 3, 1)
131+
m = dm(zeros)
132+
∇m = gradient(m -> sum(m(x)), m)[1]
133+
p, re = destructure(m)
134+
∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
135+
@test ∇p ≈ destructure(∇m)[1] rtol=1e-6
136+
end
137+
138+
@testset "destructure with buffers" begin
139+
p, re = destructure(BatchNorm(3))
140+
@test length(p) == 6
141+
142+
# https://github.com/FluxML/Flux.jl/issues/1727
143+
x = rand(Float32, 3, 4)
144+
y, back = Flux.pullback(x, p) do x, p
145+
vec(re(p)(x))
146+
end
147+
@test_nowarn back(y)
148+
b = back(y)
149+
@test size(b[1]) == size(x)
150+
@test size(b[2]) == size(p)
151+
end
152+
end
153+
154+
@testset "Train and test mode" begin
155+
mutable struct DummyLayer
156+
testing::Bool
157+
end
158+
Flux.testmode!(m::DummyLayer, testing=true) = (m.testing = testing; m)
159+
160+
c = Chain(DummyLayer(true))
161+
testmode!(c)
162+
@test c[1].testing
163+
trainmode!(c)
164+
@test !c[1].testing
165+
end

0 commit comments

Comments
 (0)