
Commit e804b60

docs and tests
Parent: 7694a9c

3 files changed: 98 additions & 25 deletions

src/functor.jl

Lines changed: 0 additions & 25 deletions
@@ -37,31 +37,6 @@ Possible values include:
 """
 trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
 
-# # push!(::Params, x) automatically discards already seen arrays
-# params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
-
-# function params!(p::Params, x, seen = IdSet())
-#   x in seen && return
-#   push!(seen, x)
-#   for child in trainable(x)
-#     params!(p, child, seen)
-#   end
-# end
-
-# function params(m...)
-#   ps = Params()
-#   params!(ps, m)
-#   return ps
-# end
-
-function params(m...)
-  ps = Params()
-  collect_params!(ps, m)
-  return ps
-end
-
-
-
 function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||

src/utils.jl

Lines changed: 42 additions & 0 deletions
@@ -705,8 +705,50 @@ function filtered_walk(filter)
   return walk
 end
 
+"""
+    params(m...)
+
+Collect trainable parameters (a.k.a. numerical arrays)
+from the input model(s) `m` into a [`Zygote.Params`](@ref) object.
+
+Only the parameters that can be reached by recursion
+on the [`trainable`](@ref) children of
+the tree with root `m` are collected.
+
+# Usage
+
+```julia
+julia> m = Dense(ones(2, 3), zeros(2))
+Dense(3, 2)  # 8 parameters
+
+julia> ps = Flux.params(m)
+Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
+
+julia> x = ones(3)
+3-element Vector{Float64}:
+ 1.0
+ 1.0
+ 1.0
+
+julia> gs = gradient(() -> sum(2 .* m(x)), ps)
+Grads(...)
+
+julia> gs[m.weight]
+2×3 Matrix{Float64}:
+ 2.0  2.0  2.0
+ 2.0  2.0  2.0
+```
+"""
+function params(m...)
+  ps = Params()
+  collect_params!(ps, m)
+  return ps
+end
+
+
 @functor Base.RefValue
 
+
 # Other
 
 """

test/utils.jl

Lines changed: 56 additions & 0 deletions
@@ -216,6 +216,62 @@ end
   r = Any[nothing,m]
   r[1] = r
   @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
+
+  @testset "use params in gradient context" begin
+    m = Chain(Dense(3,2), Dense(2,2))
+    ps = Flux.params(m)
+    gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
+    for p in ps
+      @test gs[p] ≈ ones(size(p))
+    end
+
+    w1, w2 = rand(2), rand(2)
+    ps = Flux.params(w1, w2)
+    gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
+    for p in ps
+      @test gs[p] ≈ ones(size(p))
+    end
+
+    # BROKEN TESTS
+    m = Chain(Dense(3,2), Dense(2,2))
+    @test_broken gradient(m -> sum(params(m)[1]), m) != (nothing, )
+    @test_broken gradient(m -> sum(params(m)[1]), m) != (nothing, )
+
+    gs = gradient(() -> sum(params(m)[1]), params(m))
+    @test_broken gs[params(m)[1]] !== nothing
+
+    # Tests from https://github.com/FluxML/Flux.jl/pull/1614
+    m = Dense(3, 2)
+    ps = Flux.params(m)
+    data = rand(Float32, 3, 5)
+    loss(m, x) = sum(m(x).^2)
+
+    g1 = gradient(Flux.params(m)) do
+      loss(m, data)
+    end
+    g2 = gradient(Flux.params(m)) do
+      ps = Flux.params(m) # just creating params without using them
+      loss(m, data)
+    end
+    g3 = gradient(Flux.params(m)) do
+      ps = Flux.params(m)
+      loss(m, data) + sum(sum(p) for p in ps)
+    end
+    g4 = gradient(Flux.params(m)) do
+      loss(m, data) + sum(sum(p) for p in ps)
+    end
+    g5 = gradient(Flux.params(m)) do
+      sum(Flux.params(m)[1]) + sum(Flux.params(m)[2])
+    end
+    g6 = gradient(Flux.params(m)) do
+      sum(ps[1]) + sum(ps[2])
+    end
+    @test g2[m.weight] == g1[m.weight]
+    @test g3[m.weight] == g1[m.weight] .+ 1
+    @test g4[m.weight] == g1[m.weight] .+ 1
+    @test_broken g5[m.weight] .== 1 # TODO regression with respect to master
+    @test_broken g6[m.weight] .== 1 # Not a regression, broken on master
+  end
 end
 
 @testset "Basic Stacking" begin
