Skip to content

Commit b70875f

Browse files
committed
tidy
1 parent 5a18607 commit b70875f

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/destructure.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@ const NoT = NoTangent()
88
Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
99
to a `Vector{T}`, and returns also a function which reverses this transformation.
1010
Differentiable.
11+
12+
# Example
13+
```jldoctest
14+
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0])))
15+
([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
16+
17+
julia> re([10,20,30])
18+
(x = [10.0, 20.0], y = (sin, [30.0]))
19+
```
1120
"""
1221
function destructure(::Type{T}, x) where T
13-
flat, off = alpha!(x, T[])
14-
len = length(flat)
15-
# flat, newflat -> beta(x, off, newflat; len)
22+
flat, off, len = alpha!(x, T[])
1623
flat, Restucture(x, off, len)
1724
end
1825

@@ -32,14 +39,13 @@ function alpha!(x, flat::AbstractVector)
3239
append!(flat, y)
3340
length(flat) - length(y)
3441
end
35-
flat, off
42+
flat, off, length(flat)
3643
end
3744

3845
function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
39-
flat′, off = alpha!(x, flat)
40-
len = length(flat′)
46+
flat′, off, len = alpha!(x, flat)
4147
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
42-
(flat, off), alpha_back
48+
(flat, off, len), alpha_back
4349
end
4450

4551
# This reconstructs either a model like x, or a gradient for it:
@@ -73,7 +79,6 @@ function _Tangent_biwalk(f, x, aux) # use with prune = true
7379
y isa Tuple{} && return NoT
7480
Tangent{typeof(x), typeof(y)}(y)
7581
end
76-
# _Tangent_biwalk(f, x::Tuple{}, aux) = NoT
7782

7883
function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len)
7984
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)

0 commit comments

Comments
 (0)