|
| 1 | + |
| 2 | +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo |
const NoT = NoTangent()  # shorthand for "no gradient here", used throughout the rrules below
| 4 | + |
"""
    destructure([T], model) -> vector, reconstructor

Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
to a `Vector{T}`, and returns also a function which reverses this transformation.
Differentiable.
"""
function destructure(::Type{T}, x) where T
  # alpha! flattens into a fresh Vector{T} and records one offset per leaf.
  flat, offsets = alpha!(x, T[])
  return flat, Restucture(x, offsets, length(flat))
end
| 18 | + |
# Carries everything needed to rebuild a model (or a gradient for one) from a
# flat parameter vector: the original model, the per-leaf offsets into that
# vector, and the expected vector length (checked by `beta`).
struct Restucture{T,S}
  model::T
  offsets::S
  length::Int
end
# Calling the reconstructor rebuilds a model like `model` from `flat`:
(re::Restucture)(flat) = beta(re.model, re.offsets, flat; len = re.length)
# NOTE(review): prints "Restructure" although the type is spelled "Restucture".
Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
| 26 | + |
# This flattens a model, and returns a web of offsets for later use:
function alpha!(x, flat::AbstractVector)
  isempty(flat) || error("this won't work")
  if isnumeric(x)
    return append!(flat, x), 0  # trivial case: the model is a bare leaf
  end
  offsets = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    o = length(flat)   # offset of this leaf = length before appending it
    append!(flat, y)
    o
  end
  return flat, offsets
end
| 37 | + |
function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
  flat′, off = alpha!(x, flat)
  len = length(flat′)
  # The primal returns a tuple; only the flat vector's cotangent matters here.
  function alpha_back(dys)
    dflat = first(dys)
    (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
  end
  (flat′, off), alpha_back
end
| 44 | + |
# This reconstructs either a model like x, or a gradient for it:
function beta(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
  length(flat) == len || error("wrong length")
  fmap((y, o) -> _getat(y, o, flat), x, off; exclude = isnumeric, walk, kw...)
end
| 52 | + |
# Read one leaf back out of the flat vector at (0-based) offset `o`.
# ProjectTo is just correcting eltypes.
_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
function _getat(y::AbstractArray, o::Int, flat::AbstractVector)
  vals = flat[o .+ (1:length(y))]
  return ProjectTo(y)(reshape(vals, axes(y)))
end
| 56 | + |
# fmap walk which pairs each trainable child of `x` with the matching entry of
# `aux` (the offsets structure), leaving non-trainable children untouched.
function _trainable_biwalk(f, x, aux)
  children, rebuild = functor(typeof(x), x)
  auxch, _ = functor(typeof(x), aux)
  rebuild(trainmap(f, children, _trainable(x), auxch))
end
| 62 | + |
# Map `f(child, aux)` over the trainable children; entries marked `nothing`
# in `tr` (non-trainable) pass through from `ch` unchanged.
function trainmap(f, ch, tr, aux)
  combine(c, t, a) = t === nothing ? c : f(t, a)
  map(combine, ch, tr, aux)
end
| 68 | + |
# fmap walk which pairs each trainable child with its offset, building a
# structural `Tangent` gradient for `x`.  Use with `prune = NoT` so untracked
# fields become `NoTangent`.
function _Tangent_biwalk(f, x, aux)
  ch, re = functor(typeof(x), x)
  au, _ = functor(typeof(x), aux)
  y = trainmap(f, ch, _trainable(x), au)
  y isa Tuple{} && return NoT  # a childless node contributes no gradient
  # Qualified name: `Tangent` is not among the names imported from
  # ChainRulesCore at the top of this file, so the bare name would be an
  # UndefVarError at runtime.
  ChainRulesCore.Tangent{typeof(x), typeof(y)}(y)
end
| 77 | + |
# rrule for reconstruction: the cotangent w.r.t. `flat` is obtained by
# scattering the structural gradient `dx` back into a flat buffer with gamma!.
function ChainRulesCore.rrule(::typeof(beta), x, off, flat; len)
  function beta_back(dx)
    # Allocate a fresh zeroed buffer on every invocation: gamma! accumulates
    # in place, so a buffer shared across calls (as when AD re-runs the
    # pullback) would silently double-count gradients.
    dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
    (NoT, NoT, NoT, gamma!(x, dx, off, dflat))
  end
  beta(x, off, flat; len), beta_back
end
| 83 | + |
# This is the gradient of model reconstruction, accumulating duplicates:
# tied (shared) parameters map to the same offset, and every visit adds its
# contribution into `flat`.
function gamma!(x, dx, off, flat::AbstractVector)
  x′, _ = functor(typeof(x), x)
  dx′, _ = functor(typeof(x), dx)
  off′, _ = functor(typeof(x), off)
  foreach((xᵢ, dxᵢ, oᵢ) -> gamma!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
  flat
end
function gamma!(x, dx, off::Integer, flat::AbstractVector)
  @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes, hence no fmap.
  flat
end
# A zero tangent contributes nothing.  `AbstractZero` (covering `ZeroTangent`
# and `NoTangent`) is the ChainRulesCore 1.x name; the bare `Zero` used before
# was neither imported here nor exists in ChainRulesCore ≥ 1.0.
gamma!(x, dx::ChainRulesCore.AbstractZero, off, flat::AbstractVector) = nothing
gamma!(x, dx::ChainRulesCore.AbstractZero, off::Integer, flat::AbstractVector) = nothing # resolves method ambiguity
| 98 | + |
# Least importantly, this infers the eltype if one is not given:
destructure(x) = destructure(omega(x), x)

# Promote together the eltypes of all trainable numeric leaves, starting from
# Bool (the narrowest numeric type, so any real leaf widens it).
function omega(x)
  best = Bool
  fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
    best = promote_type(best, eltype(y))
  end
  best
end
ChainRulesCore.@non_differentiable omega(::Any)
0 commit comments