Skip to content

Commit 4e45416

Browse files
destructure returns only trainable params
1 parent 067f0b4 commit 4e45416

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

src/functor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Possible values include:
3838
"""
3939
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
4040

41+
# push!(::Params, x) automatically discards already seen arrays
4142
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
4243

4344
function params!(p::Params, x, seen = IdSet())

src/utils.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -630,9 +630,11 @@ end
630630

631631
function _restructure(m, xs)
632632
i = 0
633-
= fmap(m) do x
634-
x isa AbstractArray || return x
635-
x = reshape(xs[i.+(1:length(x))], size(x))
633+
filter = (x, c) -> any(y -> c === y, trainable(x))
634+
walk = filtered_walk(filter)
635+
= fmap(m; walk) do x
636+
x isa AbstractArray{<:Number} || return x
637+
x = reshape(xs[i .+ (1:length(x))], size(x))
636638
i += length(x)
637639
return x
638640
end
@@ -673,13 +675,28 @@ modifications to the weight vector (for example, with a hypernetwork).
673675
"""
674676
function destructure(m)
675677
xs = Zygote.Buffer([])
676-
fmap(m) do x
677-
x isa AbstractArray && push!(xs, x)
678+
filter = (x, c) -> any(y -> c === y, trainable(x))
679+
walk = filtered_walk(filter)
680+
fmap(m; walk) do x
681+
x isa AbstractArray{<:Number} && push!(xs, x)
678682
return x
679683
end
680684
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
681685
end
682686

687+
function filtered_walk(filter)
688+
function walk(f, x)
689+
children, reconstruct = functor(x)
690+
mappedchildren = map(children) do c
691+
filter(x, c) ? f(c) : c
692+
end
693+
reconstruct(mappedchildren)
694+
end
695+
return walk
696+
end
697+
698+
@functor Base.RefValue
699+
683700
# Other
684701

685702
"""

test/utils.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,24 @@ end
407407
∇m = gradient(m -> sum(m(x)), m)[1]
408408
p, re = destructure(m)
409409
∇p = gradient-> sum(re(θ)(x)), p)[1]
410-
@test ∇p destructure(∇m)[1]
410+
# @show size(∇p)
411+
# @show size(destructure(∇m)[1])
412+
# @show norm(∇p - destructure(∇m)[1])
413+
@test ∇p destructure(∇m)[1] atol=1e-4
414+
end
415+
416+
@testset "destructure with buffers" begin
417+
p, re = destructure(BatchNorm(10))
418+
@test length(p) == 20
419+
420+
# https://github.com/FluxML/Flux.jl/issues/1727
421+
x = rand(Float32, 2, 3)
422+
gs, back = Flux.pullback(x, p) do x, p
423+
vec(re(p)(x))
424+
end
425+
@test_nowarn b = back(a)
426+
@test b[1] == size(x)
427+
@test b[2] == size(p)
411428
end
412429
end
413430
end

0 commit comments

Comments
 (0)