@@ -8,11 +8,18 @@ const NoT = NoTangent()
88Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
99to a `Vector{T}`, and returns also a function which reverses this transformation.
1010Differentiable.
11+
12+ # Example
13+ ```jldoctest
14+ julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0])))
15+ ([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
16+
17+ julia> re([10,20,30])
18+ (x = [10.0, 20.0], y = (sin, [30.0]))
19+ ```
1120"""
1221function destructure (:: Type{T} , x) where T
13- flat, off = alpha! (x, T[])
14- len = length (flat)
15- # flat, newflat -> beta(x, off, newflat; len)
22+ flat, off, len = alpha! (x, T[])
1623 flat, Restucture (x, off, len)
1724end
1825
@@ -32,14 +39,13 @@ function alpha!(x, flat::AbstractVector)
3239 append! (flat, y)
3340 length (flat) - length (y)
3441 end
35- flat, off
42+ flat, off, length (flat)
3643end
3744
3845function ChainRulesCore. rrule (:: typeof (alpha!), x, flat)
39- flat′, off = alpha! (x, flat)
40- len = length (flat′)
46+ flat′, off, len = alpha! (x, flat)
4147 alpha_back ((dflat, _)) = (NoT, beta (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
42- (flat′ , off), alpha_back
48+ (flat, off, len ), alpha_back
4349end
4450
4551# This reconstructs either a model like x, or a gradient for it:
@@ -73,7 +79,6 @@ function _Tangent_biwalk(f, x, aux) # use with prune = true
7379 y isa Tuple{} && return NoT
7480 Tangent {typeof(x), typeof(y)} (y)
7581end
76- # _Tangent_biwalk(f, x::Tuple{}, aux) = NoT
7782
7883function ChainRulesCore. rrule (:: typeof (beta), x, off, flat; len)
7984 dflat = map! (zero, similar (flat, float (eltype (flat))), flat)
0 commit comments