11
2- using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
2+ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
33const NoT = NoTangent ()
44
55"""
@@ -11,11 +11,11 @@ Differentiable.
1111
1212# Example
1313```jldoctest
14- julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 ])))
15- ([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
14+ julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im ])))
15+ (ComplexF64 [1.0 + 0.0im , 2.0 + 0.0im , 3.0 + 4.0im ], Restructure(NamedTuple, ..., 3))
1616
17- julia> re([10,20,30 ])
18- (x = [10 .0, 20 .0], y = (sin, [30.0 ]))
17+ julia> re([3, 5-im, 7+11im ])
18+ (x = [3 .0, 5 .0], y = (sin, ComplexF64[7.0 + 11.0im ]))
1919```
2020"""
2121function destructure (x)
2727 Restructure(Model, ..., length)
2828
2929This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30- new parameters from vector `p`. If the model is callable, then `re(x, p)` .
30+ new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)` .
3131
3232# Example
3333```julia
@@ -107,22 +107,22 @@ end
107107
108108function ChainRulesCore. rrule (:: typeof (_rebuild), x, off, flat; len)
109109 dflat = map! (zero, similar (flat, float (eltype (flat))), flat)
110- _rebuild_back (dx) = (NoT, NoT, NoT, _accumulate ! (x, dx , off, dflat))
110+ _rebuild_back (dx) = (NoT, NoT, NoT, _grad ! (x, unthunk (dx) , off, dflat))
111111 _rebuild (x, off, flat; len), _rebuild_back
112112end
113113
114114# This is the gradient of model reconstruction, accumulating duplicates:
115- function _accumulate ! (x, dx, off, flat:: AbstractVector )
115+ function _grad ! (x, dx, off, flat:: AbstractVector )
116116 x′, _ = functor (typeof (x), x)
117117 dx′, _ = functor (typeof (x), dx)
118118 off′, _ = functor (typeof (x), off)
119- foreach ((xᵢ, dxᵢ, oᵢ) -> _accumulate ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119+ foreach ((xᵢ, dxᵢ, oᵢ) -> _grad ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
120120 flat
121121end
122- function _accumulate ! (x, dx, off:: Integer , flat:: AbstractVector )
122+ function _grad ! (x, dx, off:: Integer , flat:: AbstractVector )
123123 @views flat[off .+ (1 : length (x))] .+ = dx # must visit all tied nodes
124124 flat
125125end
126- _accumulate ! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
127- _accumulate ! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
126+ _grad ! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
127+ _grad ! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
128128
0 commit comments