3737"""
3838trainmode! (m, mode = true ) = mode isa Bool ? testmode! (m, ! mode) : testmode! (m, mode)
3939
40+
41+ # Flattening models to weight vectors, and back
42+
# Rebuild a model from a flat parameter vector: walk the same trainable
# arrays that `destructure` collects, copying consecutive slices of `xs`
# back into arrays of the original shapes.
function _restructure(m, xs)
  i = 0
  filter = (x, c) -> any(y -> c === y, trainable(x))
  walk = filtered_walk(filter)
  m̄ = fmap(m; walk) do x
    x isa AbstractArray{<:Number} || return x
    x = reshape(xs[i .+ (1:length(x))], size(x))
    i += length(x)
    return x
  end
  length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
  return m̄
end

# The pullback of rebuilding a model from a flat vector is flattening the
# model's gradient back into a flat vector.
@adjoint function _restructure(m, xs)
  m̄, numel = _restructure(m, xs), length(xs)
  function _restructure_pullback(dm)
    xs′ = destructure(dm)[1]
    numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
    return (nothing, xs′)
  end
  return m̄, _restructure_pullback
end
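
# Sketch of what the adjoint enables (illustrative only; `m` and `x` are
# hypothetical): gradients flow through reconstruction back to the flat vector.
#
#   θ, re = destructure(m)
#   dθ = gradient(θ -> sum(re(θ)(x)), θ)[1]   # dθ is a flat vector like θ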

"""
    destructure(m)

Flatten a model's parameters into a single weight vector.

    julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

    julia> θ, re = destructure(m);

    julia> θ
    67-element Vector{Float32}:
    -0.1407104
    ...

The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).

    julia> re(θ .* 2)
    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
  xs = Zygote.Buffer([])
  collect_params!(xs, m)
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end
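
# Usage sketch (illustrative only; `loss`, `x`, and `y` are hypothetical):
# take a gradient step in flat-vector space, then rebuild the model.
#
#   θ, re = destructure(m)
#   g = gradient(θ -> loss(re(θ), x, y), θ)[1]
#   m′ = re(θ .- 0.1 .* g)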

# Collect every trainable numeric array in `m`, in the same traversal order
# that `_restructure` uses when copying values back.
function collect_params!(xs, m)
  filter = (x, c) -> any(y -> c === y, trainable(x))
  walk = filtered_walk(filter)
  fmap(m; walk) do x
    x isa AbstractArray{<:Number} && push!(xs, x)
    return x
  end
end

# An `fmap` walk that recurses only into children accepted by `filter`, and
# visits each node at most once so shared layers are counted a single time.
function filtered_walk(filter)
  seen = IdSet()

  function walk(f, x)
    x in seen && return x
    push!(seen, x)

    children, reconstruct = functor(x)
    mappedchildren = map(children) do c
      filter(x, c) ? f(c) : c
    end
    reconstruct(mappedchildren)
  end

  return walk
end
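
# For example (illustrative): `trainable(::BatchNorm)` returns only `β` and
# `γ`, so the walk recurses into those while leaving the running statistics
# `μ` and `σ²` out of the flat vector:
#
#   θ, re = destructure(BatchNorm(4))   # θ holds β and γ only: 8 numbers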


"""
    params(m...)

Collect the trainable parameters (i.e. the numerical arrays)
from the input model(s) `m` into a [`Zygote.Params`](@ref) object.

Only parameters that can be reached by recursing
through the [`trainable`](@ref) children of
the tree with root `m` are collected.

# Usage

```julia-repl
julia> m = Dense(ones(2, 3), zeros(2))
Dense(3, 2) # 8 parameters

julia> ps = Flux.params(m)
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> x = ones(3)
3-element Vector{Float64}:
 1.0
 1.0
 1.0

julia> gs = gradient(() -> sum(2 .* m(x)), ps)
Grads(...)

julia> gs[m.weight]
2×3 Matrix{Float64}:
 2.0  2.0  2.0
 2.0  2.0  2.0
```
"""
function params end

# TODO This causes some test regressions. Why?
# function params(m...)
#   ps = Params()
#   collect_params!(ps, m)
#   return ps
# end

# Leaves: any numeric array is a parameter.
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

# Recurse through `trainable` children, remembering visited nodes so shared
# weights are collected only once.
function params!(p::Params, x, seen = IdSet())
  x in seen && return
  push!(seen, x)
  for child in trainable(x)
    params!(p, child, seen)
  end
end

function params(m...)
  ps = Params()
  params!(ps, m)
  return ps
end
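
# Illustrative sketch: because visited nodes are tracked, a layer that appears
# twice in a model contributes its arrays once.
#
#   d = Dense(2, 2)
#   length(params(Chain(d, d)))   # 2 (one weight, one bias), not 4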

function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||