@@ -19,20 +19,42 @@ julia> re([10,20,30])
1919```
2020"""
2121function destructure (x)
22- flat, off, len = alpha (x)
23- flat, Restucture (x, off, len)
22+ flat, off, len = _flatten (x)
23+ flat, Restructure (x, off, len)
2424end
2525
26- struct Restucture{T,S}
26+ """
27+ Restructure(Model, ..., length)
28+
29+ This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30+ new parameters from vector `p`. If the model is callable, then `re(x, p)` .
31+
32+ # Example
33+ ```julia
34+ julia> using Flux, Optimisers
35+
36+ julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
37+ ([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))
38+
39+ julia> m = re(-4:1)
40+ Dense(2, 2, σ) # 6 parameters
41+
42+ julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
43+ true
44+ ```
45+ """
46+ struct Restructure{T,S}
2747 model:: T
2848 offsets:: S
2949 length:: Int
3050end
31- (re:: Restucture )(flat) = beta (re. model, re. offsets, flat; len = re. length)
32- Base. show (io:: IO , re:: Restucture{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
51+ (re:: Restructure )(flat:: AbstractVector ) = _rebuild (re. model, re. offsets, flat; len = re. length)
52+ (re:: Restructure )(x, flat:: AbstractVector ) = re (flat)(x)
53+ Base. show (io:: IO , re:: Restructure{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
54+ Base. length (re:: Restructure ) = re. length
3355
3456# This flattens a model, and returns a web of offsets for later use:
35- function alpha (x)
57+ function _flatten (x)
3658 isnumeric (x) && return vcat (vec (x)), 0 , length (x) # trivial case
3759 arrays = AbstractVector[]
3860 len = Ref (0 )
@@ -45,14 +67,14 @@ function alpha(x)
4567 reduce (vcat, arrays), off, len[]
4668end
4769
48- function ChainRulesCore. rrule (:: typeof (alpha ), x)
49- flat, off, len = alpha (x)
50- alpha_back ((dflat, _)) = (NoT, beta (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
51- (flat, off, len), alpha_back
70+ function ChainRulesCore. rrule (:: typeof (_flatten ), x)
71+ flat, off, len = _flatten (x)
72+ _flatten_back ((dflat, _)) = (NoT, _rebuild (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
73+ (flat, off, len), _flatten_back
5274end
5375
5476# This reconstructs either a model like x, or a gradient for it:
55- function beta (x, off, flat:: AbstractVector ; len, walk = _trainable_biwalk, kw... )
77+ function _rebuild (x, off, flat:: AbstractVector ; len, walk = _trainable_biwalk, kw... )
5678 len == length (flat) || error (" wrong length" )
5779 fmap (x, off; exclude = isnumeric, walk, kw... ) do y, o
5880 _getat (y, o, flat)
@@ -66,40 +88,41 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) =
6688function _trainable_biwalk (f, x, aux)
6789 ch, re = functor (typeof (x), x)
6890 au, _ = functor (typeof (x), aux)
69- trainmap (f, ch, _trainable (x), au) |> re
91+ _trainmap (f, ch, _trainable (x), au) |> re
7092end
7193
72- function trainmap (f, ch, tr, aux)
73- map (ch, tr, aux) do c, t, a
94+ function _trainmap (f, ch, tr, aux)
95+ map (ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
7496 isnothing (t) ? c : f (t, a)
7597 end
7698end
7799
78- function _Tangent_biwalk (f, x, aux) # use with prune = true
100+ function _Tangent_biwalk (f, x, aux) # use with prune = NoT
79101 ch, re = functor (typeof (x), x)
80102 au, _ = functor (typeof (x), aux)
81- y = trainmap (f, ch, _trainable (x), au)
103+ y = _trainmap (f, ch, _trainable (x), au)
82104 y isa Tuple{} && return NoT
83105 Tangent {typeof(x), typeof(y)} (y)
84106end
85107
86- function ChainRulesCore. rrule (:: typeof (beta ), x, off, flat; len)
108+ function ChainRulesCore. rrule (:: typeof (_rebuild ), x, off, flat; len)
87109 dflat = map! (zero, similar (flat, float (eltype (flat))), flat)
88- beta_back (dx) = (NoT, NoT, NoT, gamma ! (x, dx, off, dflat))
89- beta (x, off, flat; len), beta_back
110+ _rebuild_back (dx) = (NoT, NoT, NoT, _accumulate ! (x, dx, off, dflat))
111+ _rebuild (x, off, flat; len), _rebuild_back
90112end
91113
92114# This is the gradient of model reconstruction, accumulating duplicates:
93- function gamma ! (x, dx, off, flat:: AbstractVector )
115+ function _accumulate ! (x, dx, off, flat:: AbstractVector )
94116 x′, _ = functor (typeof (x), x)
95117 dx′, _ = functor (typeof (x), dx)
96118 off′, _ = functor (typeof (x), off)
97- foreach ((xᵢ, dxᵢ, oᵢ) -> gamma ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119+ foreach ((xᵢ, dxᵢ, oᵢ) -> _accumulate ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
98120 flat
99121end
100- function gamma ! (x, dx, off:: Integer , flat:: AbstractVector )
101- @views flat[off .+ (1 : length (x))] .+ = dx # must visit all tied nodes, hence no fmap.
122+ function _accumulate ! (x, dx, off:: Integer , flat:: AbstractVector )
123+ @views flat[off .+ (1 : length (x))] .+ = dx # must visit all tied nodes
102124 flat
103125end
104- gamma! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
105- gamma! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
126+ _accumulate! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
127+ _accumulate! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
128+
0 commit comments