Commit a649870

update
1 parent 5927ef1 commit a649870

File tree

5 files changed: +306 -194 lines changed


src/functor.jl

Lines changed: 134 additions & 0 deletions
@@ -37,6 +37,140 @@ Possible values include:
 """
 trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
 
+
+# Flattening models to weight vectors, and back
+
+function _restructure(m, xs)
+  i = 0
+  filter = (x, c) -> any(y -> c === y, trainable(x))
+  walk = filtered_walk(filter)
+  m̄ = fmap(m; walk) do x
+    x isa AbstractArray{<:Number} || return x
+    x = reshape(xs[i .+ (1:length(x))], size(x))
+    i += length(x)
+    return x
+  end
+  length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
+  return m̄
+end
+
+@adjoint function _restructure(m, xs)
+  m̄, numel = _restructure(m, xs), length(xs)
+  function _restructure_pullback(dm)
+    xs′ = destructure(dm)[1]
+    numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
+    return (nothing, xs′)
+  end
+  return m̄, _restructure_pullback
+end
+
+"""
+    destructure(m)
+
+Flatten a model's parameters into a single weight vector.
+
+    julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+
+    julia> θ, re = destructure(m);
+
+    julia> θ
+    67-element Vector{Float32}:
+    -0.1407104
+    ...
+
+The second return value `re` allows you to reconstruct the original network after making
+modifications to the weight vector (for example, with a hypernetwork).
+
+    julia> re(θ .* 2)
+    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+"""
+function destructure(m)
+  xs = Zygote.Buffer([])
+  collect_params!(xs, m)
+  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
+end
+
+function collect_params!(xs, m)
+  filter = (x, c) -> any(y -> c === y, trainable(x))
+  walk = filtered_walk(filter)
+  fmap(m; walk) do x
+    x isa AbstractArray{<:Number} && push!(xs, x)
+    return x
+  end
+end
+
+function filtered_walk(filter)
+  seen = IdSet()
+
+  function walk(f, x)
+    x in seen && return x
+    push!(seen, x)
+
+    children, reconstruct = functor(x)
+    mappedchildren = map(children) do c
+      filter(x, c) ? f(c) : c
+    end
+    reconstruct(mappedchildren)
+  end
+
+  return walk
+end
+
+"""
+    params(m...)
+
+Collect trainable parameters (a.k.a. numerical arrays)
+from the input model(s) `m` into a [`Zygote.Params`](@ref) object.
+
+Only the parameters that can be reached by recursion
+on the [`trainable`](@ref) children of
+the tree with root `m` are collected.
+
+# Usage
+
+```julia-repl
+julia> m = Dense(ones(2, 3), zeros(2))
+Dense(3, 2)   # 8 parameters
+
+julia> ps = Flux.params(m)
+Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
+
+julia> x = ones(3)
+3-element Vector{Float64}:
+ 1.0
+ 1.0
+ 1.0
+
+julia> gs = gradient(() -> sum(2 .* m(x)), ps)
+Grads(...)
+
+julia> gs[m.weight]
+2×3 Matrix{Float64}:
+ 2.0  2.0  2.0
+ 2.0  2.0  2.0
+```
+"""
+function params end
+
+## TODO This causes some test regressions. Why?
+# function params(m...)
+#   ps = Params()
+#   collect_params!(ps, m)
+#   return ps
+# end
+
+params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
+
+function params!(p::Params, x, seen = IdSet())
+  x in seen && return
+  push!(seen, x)
+  for child in trainable(x)
+    params!(p, child, seen)
+  end
+end
+
+function params(m...)
+  ps = Params()
+  params!(ps, m)
+  return ps
+end
+
 function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
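
A quick illustration of how the new `destructure`/`_restructure` pair is meant to be used, and of what the `@adjoint` above buys: gradients of a loss written in terms of `re(θ)` flow back as a flat vector. This is a minimal sketch, not part of the commit; the model, data, and step size are made up, and it assumes `Flux` and `Zygote` are loaded with the definitions above available.

```julia
using Flux, Zygote

# Hypothetical model and data, just to exercise the API added above.
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
x = rand(Float32, 10, 8)
y = Flux.onehotbatch(rand(1:2, 8), 1:2)

θ, re = Flux.destructure(m)      # θ: flat Vector{Float32}; re: vector -> model
@assert re(θ)(x) ≈ m(x)          # reconstruction reproduces the original model

# Thanks to the @adjoint for _restructure, the gradient with respect to θ
# comes back as a flat vector of the same length as θ.
loss(θ) = Flux.Losses.crossentropy(re(θ)(x), y)
g = Zygote.gradient(loss, θ)[1]
θ = θ .- 0.1f0 .* g              # one plain gradient-descent step
```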
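Likewise, a sketch of what the `trainable`-based filtering means for `params`: only arrays reachable through `trainable` children are collected. The `Affine` layer below is hypothetical, following the `@functor`/`trainable` customisation pattern this code relies on.

```julia
using Flux
using Flux: @functor

# Hypothetical layer carrying a buffer that should not be trained.
struct Affine
  W
  b
  buffer
end
@functor Affine

(a::Affine)(x) = a.W * x .+ a.b

# Overriding `trainable` restricts the recursion in params!/filtered_walk,
# so `buffer` is never reached.
Flux.trainable(a::Affine) = (a.W, a.b)

a = Affine(rand(2, 3), rand(2), rand(5))
ps = Flux.params(a)
@assert length(ps) == 2   # W and b are collected; `buffer` is skipped
```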

src/utils.jl

Lines changed: 0 additions & 122 deletions
@@ -626,128 +626,6 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
   [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
 end
 
-# Flattening models to weight vectors, and back
-
-function _restructure(m, xs)
-  i = 0
-  filter = (x, c) -> any(y -> c === y, trainable(x))
-  walk = filtered_walk(filter)
-  m̄ = fmap(m; walk) do x
-    x isa AbstractArray{<:Number} || return x
-    x = reshape(xs[i .+ (1:length(x))], size(x))
-    i += length(x)
-    return x
-  end
-  length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
-  return m̄
-end
-
-@adjoint function _restructure(m, xs)
-  m̄, numel = _restructure(m, xs), length(xs)
-  function _restructure_pullback(dm)
-    xs′ = destructure(dm)[1]
-    numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
-    return (nothing, xs′)
-  end
-  return m̄, _restructure_pullback
-end
-
-"""
-    destructure(m)
-
-Flatten a model's parameters into a single weight vector.
-
-    julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-
-    julia> θ, re = destructure(m);
-
-    julia> θ
-    67-element Vector{Float32}:
-    -0.1407104
-    ...
-
-The second return value `re` allows you to reconstruct the original network after making
-modifications to the weight vector (for example, with a hypernetwork).
-
-    julia> re(θ .* 2)
-    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-"""
-function destructure(m)
-  xs = Zygote.Buffer([])
-  collect_params!(xs, m)
-  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
-end
-
-function collect_params!(xs, m)
-  filter = (x, c) -> any(y -> c === y, trainable(x))
-  walk = filtered_walk(filter)
-  fmap(m; walk) do x
-    x isa AbstractArray{<:Number} && push!(xs, x)
-    return x
-  end
-end
-
-function filtered_walk(filter)
-  seen = IdSet()
-
-  function walk(f, x)
-    x in seen && return x
-    push!(seen, x)
-
-    children, reconstruct = functor(x)
-    mappedchildren = map(children) do c
-      filter(x, c) ? f(c) : c
-    end
-    reconstruct(mappedchildren)
-  end
-
-  return walk
-end
-
-"""
-    params(m...)
-
-Collect trainable parameters (a.k.a. numerical arrays)
-from the input model(s) `m` into a [`Zygote.Params`](@ref) object.
-
-Only the parameters that can be reached by recursion
-on the [`trainable`](@ref) children of
-the tree with root `m` are collected.
-
-# Usage
-
-```julia
-julia> m = Dense(ones(2, 3), zeros(2))
-Dense(3, 2)   # 8 parameters
-
-julia> ps = Flux.params(m)
-Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
-
-julia> x = ones(3)
-3-element Vector{Float64}:
- 1.0
- 1.0
- 1.0
-
-julia> gs = gradient(() -> sum(2 .* m(x)), ps)
-Grads(...)
-
-julia> gs[m.weight]
-2×3 Matrix{Float64}:
- 2.0  2.0  2.0
- 2.0  2.0  2.0
-```
-"""
-function params(m...)
-  ps = Params()
-  collect_params!(ps, m)
-  return ps
-end
-
-
-@functor Base.RefValue
-
 
 # Other