
Commit f7c1a7f

destructure, take II
1 parent 4155bcd commit f7c1a7f

6 files changed: +206 −3 lines changed


docs/src/api.md

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,12 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`:
 Optimisers.trainable
 ```
 
+Such restrictions are also obeyed by this function for flattening a model:
+
+```@docs
+Optimisers.destructure
+```
+
 ## Rule Definition
 
 ```@docs
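For orientation, a minimal usage sketch of the function being documented here, mirroring the `m3` model from test/destructure.jl below (names and values are illustrative only):

using Optimisers

model = (x = [1.0, 2.0, 3.0], y = sin, z = [4.0, 5.0, 6.0])   # like m3 in the tests

flat, re = destructure(model)   # flat == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; `sin` is not numeric, so it is skipped
model2 = re(10 .* flat)         # rebuilds the same nesting from any vector of the right length
model2.z                        # == [40.0, 50.0, 60.0]; model2.y === sin, passed through unchanged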

src/Optimisers.jl

Lines changed: 3 additions & 1 deletion
@@ -4,8 +4,10 @@ using Functors: functor, fmap, isleaf
 using LinearAlgebra
 
 include("interface.jl")
-include("rules.jl")
+include("destructure.jl")
+export destructure
 
+include("rules.jl")
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain

src/destructure.jl

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
+const NoT = NoTangent()
+
+"""
+    destructure([T], model) -> vector, reconstructor
+
+Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
+to a `Vector{T}`, and returns also a function which reverses this transformation.
+Differentiable.
+"""
+function destructure(::Type{T}, x) where T
+  flat, off = alpha!(x, T[])
+  len = length(flat)
+  # flat, newflat -> beta(x, off, newflat; len)
+  flat, Restucture(x, off, len)
+end
+
+struct Restucture{T,S}
+  model::T
+  offsets::S
+  length::Int
+end
+(re::Restucture)(flat) = beta(re.model, re.offsets, flat; len = re.length)
+Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
+
+# This flattens a model, and returns a web of offsets for later use:
+function alpha!(x, flat::AbstractVector)
+  isempty(flat) || error("this won't work")
+  isnumeric(x) && return append!(flat, x), 0  # trivial case
+  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+    append!(flat, y)
+    length(flat) - length(y)
+  end
+  flat, off
+end
+
+function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
+  flat′, off = alpha!(x, flat)
+  len = length(flat′)
+  alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
+  (flat′, off), alpha_back
+end
+
+# This reconstructs either a model like x, or a gradient for it:
+function beta(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
+  len == length(flat) || error("wrong length")
+  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
+    _getat(y, o, flat)
+  end
+end
+
+_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
+_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
+  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
+
+function _trainable_biwalk(f, x, aux)
+  ch, re = functor(typeof(x), x)
+  au, _ = functor(typeof(x), aux)
+  trainmap(f, ch, _trainable(x), au) |> re
+end
+
+function trainmap(f, ch, tr, aux)
+  map(ch, tr, aux) do c, t, a
+    isnothing(t) ? c : f(t, a)
+  end
+end
+
+function _Tangent_biwalk(f, x, aux)  # use with prune = true
+  ch, re = functor(typeof(x), x)
+  au, _ = functor(typeof(x), aux)
+  y = trainmap(f, ch, _trainable(x), au)
+  y isa Tuple{} && return NoT
+  Tangent{typeof(x), typeof(y)}(y)
+end
+# _Tangent_biwalk(f, x::Tuple{}, aux) = NoT
+
+function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len)
+  dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
+  beta_back(dx) = (NoT, NoT, NoT, gamma!(x, dx, off, dflat))
+  beta(x, off, flat; len), beta_back
+end
+
+# This is the gradient of model reconstruction, accumulating duplicates:
+function gamma!(x, dx, off, flat::AbstractVector)
+  x′, _ = functor(typeof(x), x)
+  dx′, _ = functor(typeof(x), dx)
+  off′, _ = functor(typeof(x), off)
+  foreach((xᵢ, dxᵢ, oᵢ) -> gamma!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
+  flat
+end
+function gamma!(x, dx, off::Integer, flat::AbstractVector)
+  @views flat[off .+ (1:length(x))] .+= dx  # must visit all tied nodes, hence no fmap.
+  flat
+end
+gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing
+gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing  # ambiguity
+
+# Least importantly, this infers the eltype if one is not given:
+destructure(x) = destructure(omega(x), x)
+function omega(x)
+  T = Bool
+  fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
+    T = promote_type(T, eltype(y))
+  end
+  T
+end
+ChainRulesCore.@non_differentiable omega(::Any)
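The division of labour in this new file: `alpha!` flattens the trainable numeric leaves into one vector and records, in a tree the same shape as the model, the offset at which each leaf starts; `beta` rebuilds a model (or a tangent) by slicing the flat vector at those offsets; `gamma!` is the reverse of rebuilding, writing gradients back into one flat vector and accumulating the contributions of tied arrays (such as `m4.x === m4.y` in the tests). A much-simplified, standalone sketch of the offset bookkeeping, ignoring Functors, trainability, `ProjectTo` and ties; `flatten` and `rebuild` are hypothetical helpers, not package code:

# Handles only nested Tuples/NamedTuples whose numeric leaves are arrays.
function flatten(x, flat = Float64[])
  if x isa AbstractArray{<:Number}
    off = length(flat)                         # this leaf starts after `off` entries
    append!(flat, x)
    return flat, off
  elseif x isa Union{Tuple, NamedTuple}
    offs = map(xi -> flatten(xi, flat)[2], x)  # offsets, same shape as x
    return flat, offs
  else
    return flat, nothing                       # non-numeric leaf, e.g. an activation function
  end
end

function rebuild(x, offs, flat)
  if x isa AbstractArray{<:Number}
    return reshape(flat[offs .+ (1:length(x))], size(x))
  elseif x isa Union{Tuple, NamedTuple}
    return map((xi, oi) -> rebuild(xi, oi, flat), x, offs)
  else
    return x                                   # pass non-numeric leaves through unchanged
  end
end

m = (x = [1.0, 2.0], y = (sin, [3.0, 4.0]))
flat, offs = flatten(m)        # flat == [1, 2, 3, 4], offs == (x = 0, y = (nothing, 2))
rebuild(m, offs, 10 .* flat)   # (x = [10.0, 20.0], y = (sin, [30.0, 40.0]))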

src/interface.jl

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ trainable(x) = functor(x)[1]
 
 _trainable(x) = _trainable(functor(x)[1], trainable(x))
 _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
-_trainable(ch::Tuple, tr::Tuple) = tr
+_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
 function _trainable(ch::NamedTuple, tr::Tuple)  # for old Flux-style no-names tuple
   @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple"
   map(c -> c in tr ? c : nothing, ch)
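The new `Vararg{Any,N}` signature matches only when the functor children and the `trainable` tuple have the same length `N`; a mismatched pair no longer hits this method, so presumably the intent is that a length disagreement now fails loudly rather than silently returning the shorter tuple. A standalone sketch of that dispatch behaviour, using a hypothetical `samelen` rather than package code:

samelen(a::Tuple{Vararg{Any,N}}, b::Tuple{Vararg{Any,N}}) where {N} = b   # same trick as _trainable above

samelen((1, 2), (:a, :b))      # (:a, :b), lengths agree with N = 2
samelen((1, 2), (:a, :b, :c))  # MethodError: no single N fits both tuples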

test/destructure.jl

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+
+m1 = collect(1:3.0)
+m2 = (collect(1:3.0), collect(4:6.0))
+m3 = (x = m1, y = sin, z = collect(4:6.0))
+m4 = (x = m1, y = m1, z = collect(4:6.0))
+m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
+m6 = (a = m1, b = [4.0 + im], c = m1)
+m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
+
+@testset "flatten & restore" begin
+  @test destructure(Int, m1)[1] isa Vector{Int}
+  @test destructure(m1)[1] isa Vector{Float64}
+
+  @test destructure(m1)[1] == 1:3
+  @test destructure(m2)[1] == 1:6
+  @test destructure(m3)[1] == 1:6
+  @test destructure(m4)[1] == 1:6
+  @test destructure(m5)[1] == vcat(1:6, 4:6)
+  @test destructure(m6)[1] == vcat(1:3, 4 + im)
+
+  @test destructure(m1)[2](7:9) == [7,8,9]
+  @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
+  @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
+  m4′ = destructure(m4)[2](4:9)
+  @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
+  @test m4′.x === m4′.y
+  m5′ = destructure(m5)[2](reverse(1:9))
+  @test m5′.a[1].x === m5′.b[1]
+  @test m5′.b[2] === false
+  m6′ = destructure(m6)[2]((4:7) .+ (1:4) .* im)
+  @test m6′.a isa Vector{Float64}
+  @test m6′.a == 4:6
+  @test m6′.a === m6′.c
+  @test m6′.b == [7 + 4im]
+
+  @test destructure(m7)[1] == 1:3
+  m7′ = destructure(m7)[2]([10,20,30])
+  @test m7′.a == (sin, [10,20,30])
+  @test m7′.b == (cos, [4,5,6])
+  @test m7′.c == (tan, [7,8,9])
+
+  @test_throws Exception destructure(m7)[2]([10,20])
+  @test_throws Exception destructure(m7)[2]([10,20,30,40])
+end
+
+@testset "gradient of flatten" begin
+  @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
+  @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
+  @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
+  @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
+  @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
+
+  g5 = gradient(m -> destructure(m)[1][3], m5)[1]
+  @test g5.a[1].x == [0,0,1]
+  @test g5.a[2] === nothing
+
+  g6 = gradient(m -> imag(destructure(m)[1][4]), m6)[1]
+  @test g6.a == [0,0,0]
+  @test g6.a isa Vector{Float64}
+  @test g6.b == [0+im]
+end
+
+@testset "gradient of rebuild" begin
+  re1 = destructure(m1)[2]
+  @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
+  re2 = destructure(m2)[2]
+  @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
+  re3 = destructure(m3)[2]
+  @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
+  @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
+
+  re4 = destructure(m4)[2]
+  @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
+  @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
+  @test gradient(rand(6)) do x
+    m = re4(x)
+    m.x[1] + 2*m.y[2] + 3*m.z[3]
+  end[1] == [1,2,0, 0,0,3]
+
+  re7 = destructure(m7)[2]
+  @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
+  @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
+  @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
+end
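These three testsets differentiate through both directions of `destructure`. A sketch of the round trip they exercise, assuming `gradient` is provided by Zygote as it appears to be in this test suite (the model, `loss`, and values are illustrative):

using Optimisers, Zygote

model = (x = [1.0, 2.0, 3.0], y = sin, z = [4.0, 5.0, 6.0])
flat, re = destructure(model)

loss(v) = sum(abs2, re(v).x) + sum(re(v).z)   # any scalar function of the rebuilt model
gradient(loss, flat)[1]                       # == [2, 4, 6, 1, 1, 1], a single flat vector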

test/runtests.jl

Lines changed: 4 additions & 1 deletion
@@ -164,8 +164,11 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 @test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
 end
 
-@info "finished feature testing"
 end
+@testset verbose=true "Optimisation Rules" begin
+  include("destructure.jl")
+end
+@info "finished feature testing"
 @testset verbose=true "Optimisation Rules" begin
   include("rules.jl")
 end
