Skip to content

Commit e89a7f3

Browse files
committed
rename everything
1 parent c729260 commit e89a7f3

File tree

4 files changed

+52
-27
lines changed

4 files changed

+52
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
1414
ChainRulesCore = "1"
15-
Functors = "0.2.7"
15+
Functors = "0.2.8"
1616
julia = "1.6"
1717

1818
[extras]

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Such restrictions are also obeyed by this function for flattening a model:
4646

4747
```@docs
4848
Optimisers.destructure
49+
Optimisers.Restructure
4950
```
5051

5152
## Rule Definition

src/Optimisers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ using Functors: functor, fmap, isleaf
44
using LinearAlgebra
55

66
include("interface.jl")
7+
78
include("destructure.jl")
8-
export destructure
9+
export destructure, total, total2
910

1011
include("rules.jl")
1112
export Descent, ADAM, Momentum, Nesterov, RMSProp,

src/destructure.jl

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,42 @@ julia> re([10,20,30])
1919
```
2020
"""
2121
function destructure(x)
22-
flat, off, len = alpha(x)
23-
flat, Restucture(x, off, len)
22+
flat, off, len = _flatten(x)
23+
flat, Restructure(x, off, len)
2424
end
2525

26-
struct Restucture{T,S}
26+
"""
27+
Restructure(Model, ..., length)
28+
29+
This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30+
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
31+
32+
# Example
33+
```julia
34+
julia> using Flux, Optimisers
35+
36+
julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
37+
([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))
38+
39+
julia> m = re(-4:1)
40+
Dense(2, 2, σ) # 6 parameters
41+
42+
julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
43+
true
44+
```
45+
"""
46+
struct Restructure{T,S}
2747
model::T
2848
offsets::S
2949
length::Int
3050
end
31-
(re::Restucture)(flat) = beta(re.model, re.offsets, flat; len = re.length)
32-
Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
51+
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat; len = re.length)
52+
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
53+
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
54+
Base.length(re::Restructure) = re.length
3355

3456
# This flattens a model, and returns a web of offsets for later use:
35-
function alpha(x)
57+
function _flatten(x)
3658
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
3759
arrays = AbstractVector[]
3860
len = Ref(0)
@@ -45,14 +67,14 @@ function alpha(x)
4567
reduce(vcat, arrays), off, len[]
4668
end
4769

48-
function ChainRulesCore.rrule(::typeof(alpha), x)
49-
flat, off, len = alpha(x)
50-
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
51-
(flat, off, len), alpha_back
70+
function ChainRulesCore.rrule(::typeof(_flatten), x)
71+
flat, off, len = _flatten(x)
72+
_flatten_back((dflat, _)) = (NoT, _rebuild(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
73+
(flat, off, len), _flatten_back
5274
end
5375

5476
# This reconstructs either a model like x, or a gradient for it:
55-
function beta(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
77+
function _rebuild(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
5678
len == length(flat) || error("wrong length")
5779
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
5880
_getat(y, o, flat)
@@ -66,40 +88,41 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) =
6688
function _trainable_biwalk(f, x, aux)
6789
ch, re = functor(typeof(x), x)
6890
au, _ = functor(typeof(x), aux)
69-
trainmap(f, ch, _trainable(x), au) |> re
91+
_trainmap(f, ch, _trainable(x), au) |> re
7092
end
7193

72-
function trainmap(f, ch, tr, aux)
73-
map(ch, tr, aux) do c, t, a
94+
function _trainmap(f, ch, tr, aux)
95+
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
7496
isnothing(t) ? c : f(t, a)
7597
end
7698
end
7799

78-
function _Tangent_biwalk(f, x, aux) # use with prune = true
100+
function _Tangent_biwalk(f, x, aux) # use with prune = NoT
79101
ch, re = functor(typeof(x), x)
80102
au, _ = functor(typeof(x), aux)
81-
y = trainmap(f, ch, _trainable(x), au)
103+
y = _trainmap(f, ch, _trainable(x), au)
82104
y isa Tuple{} && return NoT
83105
Tangent{typeof(x), typeof(y)}(y)
84106
end
85107

86-
function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len)
108+
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
87109
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
88-
beta_back(dx) = (NoT, NoT, NoT, gamma!(x, dx, off, dflat))
89-
beta(x, off, flat; len), beta_back
110+
_rebuild_back(dx) = (NoT, NoT, NoT, _accumulate!(x, dx, off, dflat))
111+
_rebuild(x, off, flat; len), _rebuild_back
90112
end
91113

92114
# This is the gradient of model reconstruction, accumulating duplicates:
93-
function gamma!(x, dx, off, flat::AbstractVector)
115+
function _accumulate!(x, dx, off, flat::AbstractVector)
94116
x′, _ = functor(typeof(x), x)
95117
dx′, _ = functor(typeof(x), dx)
96118
off′, _ = functor(typeof(x), off)
97-
foreach((xᵢ, dxᵢ, oᵢ) -> gamma!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119+
foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
98120
flat
99121
end
100-
function gamma!(x, dx, off::Integer, flat::AbstractVector)
101-
@views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes, hence no fmap.
122+
function _accumulate!(x, dx, off::Integer, flat::AbstractVector)
123+
@views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes
102124
flat
103125
end
104-
gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing
105-
gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
126+
_accumulate!(x, dx::Zero, off, flat::AbstractVector) = nothing
127+
_accumulate!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
128+

0 commit comments

Comments
 (0)