Skip to content

Commit 7694a9c

Browse files
params
1 parent 4e45416 commit 7694a9c

File tree

4 files changed

+39
-20
lines changed

4 files changed

+39
-20
lines changed

src/Flux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Zygote, MacroTools, Juno, Reexport
88
using MacroTools: @forward
99
@reexport using NNlib
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
11+
using Functors: @functor, functor, fmap, isleaf
1112
export gradient
1213

1314
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,

src/functor.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import Adapt: adapt, adapt_storage
22
using LinearAlgebra: Cholesky
33
using Zygote: IdSet
4-
import Functors: Functors, @functor, functor, fmap, isleaf
54
using SparseArrays: AbstractSparseArray
65

76
trainable(m) = functor(m)[1]
@@ -38,23 +37,31 @@ Possible values include:
3837
"""
3938
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
4039

41-
# push!(::Params, x) automatically discards already seen arrays
42-
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
40+
# # push!(::Params, x) automatically discards already seen arrays
41+
# params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
4342

44-
function params!(p::Params, x, seen = IdSet())
45-
x in seen && return
46-
push!(seen, x)
47-
for child in trainable(x)
48-
params!(p, child, seen)
49-
end
50-
end
43+
# function params!(p::Params, x, seen = IdSet())
44+
# x in seen && return
45+
# push!(seen, x)
46+
# for child in trainable(x)
47+
# params!(p, child, seen)
48+
# end
49+
# end
50+
51+
# function params(m...)
52+
# ps = Params()
53+
# params!(ps, m)
54+
# return ps
55+
# end
5156

5257
function params(m...)
5358
ps = Params()
54-
params!(ps, m)
59+
collect_params!(ps, m)
5560
return ps
5661
end
5762

63+
64+
5865
function loadparams!(m, xs)
5966
for (p, x) in zip(params(m), xs)
6067
size(p) == size(x) ||

src/utils.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,23 +675,33 @@ modifications to the weight vector (for example, with a hypernetwork).
675675
"""
676676
function destructure(m)
677677
xs = Zygote.Buffer([])
678+
collect_params!(xs, m)
679+
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
680+
end
681+
682+
function collect_params!(xs, m)
678683
filter = (x, c) -> any(y -> c === y, trainable(x))
679684
walk = filtered_walk(filter)
680685
fmap(m; walk) do x
681686
x isa AbstractArray{<:Number} && push!(xs, x)
682687
return x
683688
end
684-
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
685689
end
686690

687691
function filtered_walk(filter)
692+
seen = IdSet()
693+
688694
function walk(f, x)
695+
x in seen && return x
696+
push!(seen, x)
697+
689698
children, reconstruct = functor(x)
690699
mappedchildren = map(children) do c
691700
filter(x, c) ? f(c) : c
692701
end
693702
reconstruct(mappedchildren)
694703
end
704+
695705
return walk
696706
end
697707

test/utils.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,22 +409,23 @@ end
409409
∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
410410
# @show size(∇p)
411411
# @show size(destructure(∇m)[1])
412-
# @show norm(∇p - destructure(∇m)[1])
412+
@show norm(∇p - destructure(∇m)[1])
413413
@test ∇p ≈ destructure(∇m)[1] atol=1e-4
414414
end
415415

416416
@testset "destructure with buffers" begin
417-
p, re = destructure(BatchNorm(10))
418-
@test length(p) == 20
417+
p, re = destructure(BatchNorm(3))
418+
@test length(p) == 6
419419

420420
# https://github.com/FluxML/Flux.jl/issues/1727
421-
x = rand(Float32, 2, 3)
422-
gs, back = Flux.pullback(x, p) do x, p
421+
x = rand(Float32, 3, 4)
422+
y, back = Flux.pullback(x, p) do x, p
423423
vec(re(p)(x))
424424
end
425-
@test_nowarn b = back(a)
426-
@test b[1] == size(x)
427-
@test b[2] == size(p)
425+
@test_nowarn back(y)
426+
b = back(y)
427+
@test size(b[1]) == size(x)
428+
@test size(b[2]) == size(p)
428429
end
429430
end
430431
end

0 commit comments

Comments
 (0)