@@ -3,10 +3,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
33const NoT = NoTangent ()
44
55"""
6- destructure([T], model) -> vector, reconstructor
6+ destructure(model) -> vector, reconstructor
77
88Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
9- to a `Vector{T}` , and returns also a function which reverses this transformation.
9+ to a vector , and returns also a function which reverses this transformation.
1010Differentiable.
1111
1212# Example
@@ -18,8 +18,8 @@ julia> re([10,20,30])
1818(x = [10.0, 20.0], y = (sin, [30.0]))
1919```
2020"""
21- function destructure (:: Type{T} , x) where T
22- flat, off, len = alpha! (x, T[] )
21+ function destructure (x)
22+ flat, off, len = alpha (x )
2323 flat, Restucture (x, off, len)
2424end
2525
3232Base. show (io:: IO , re:: Restucture{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
3333
3434# This flattens a model, and returns a web of offsets for later use:
35- function alpha! (x, flat:: AbstractVector )
36- isempty (flat) || error (" this won't work" )
37- isnumeric (x) && return append! (flat, x), 0 # trivial case
35+ function alpha (x)
36+ isnumeric (x) && return vcat (vec (x)), 0 , length (x) # trivial case
37+ arrays = AbstractVector[]
38+ len = Ref (0 )
3839 off = fmap (x; exclude = isnumeric, walk = (f, z) -> map (f, _trainable (z))) do y
39- append! (flat, y)
40- length (flat) - length (y)
40+ push! (arrays, vec (y))
41+ o = len[]
42+ len[] = o + length (y)
43+ o
4144 end
42- flat, off, length (flat)
45+ reduce (vcat, arrays), off, len[]
4346end
4447
45- function ChainRulesCore. rrule (:: typeof (alpha! ), x, flat )
46- flat′ , off, len = alpha! (x, flat )
47- alpha_back ((dflat, _)) = (NoT, beta (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT )
48+ function ChainRulesCore. rrule (:: typeof (alpha), x)
49+ flat, off, len = alpha (x )
50+ alpha_back ((dflat, _)) = (NoT, beta (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
4851 (flat, off, len), alpha_back
4952end
5053
@@ -100,14 +103,3 @@ function gamma!(x, dx, off::Integer, flat::AbstractVector)
100103end
101104gamma! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
102105gamma! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
103-
104- # Least importantly, this infers the eltype if one is not given:
105- destructure (x) = destructure (omega (x), x)
106- function omega (x)
107- T = Bool
108- fmap (x; exclude = isnumeric, walk = (f, z) -> foreach (f, _trainable (z))) do y
109- T = promote_type (T, eltype (y))
110- end
111- T
112- end
113- ChainRulesCore. @non_differentiable omega (:: Any )
0 commit comments