@@ -48,7 +48,7 @@ struct Restructure{T,S}
4848 offsets:: S
4949 length:: Int
5050end
51- (re:: Restructure )(flat:: AbstractVector ) = _rebuild (re. model, re. offsets, flat; len = re. length)
51+ (re:: Restructure )(flat:: AbstractVector ) = _rebuild (re. model, re. offsets, flat, re. length)
5252(re:: Restructure )(x, flat:: AbstractVector ) = re (flat)(x)
5353Base. show (io:: IO , re:: Restructure{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
5454Base. length (re:: Restructure ) = re. length
6969
7070function ChainRulesCore. rrule (:: typeof (_flatten), x)
7171 flat, off, len = _flatten (x)
72- _flatten_back ((dflat, _)) = (NoT, _rebuild (x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len ))
72+ _flatten_back ((dflat, _, _ )) = (NoT, _rebuild (x, off, dflat, len ; walk = _Tangent_biwalk, prune = NoT))
7373 (flat, off, len), _flatten_back
7474end
7575
7676# This reconstructs either a model like x, or a gradient for it:
77- function _rebuild (x, off, flat:: AbstractVector ; len, walk = _trainable_biwalk, kw... )
78- len == length (flat) || error ( " wrong length" )
77+ function _rebuild (x, off, flat:: AbstractVector , len = length (flat); walk = _trainable_biwalk, kw... )
78+ len == length (flat) || throw ( DimensionMismatch ( " Rebuild expected a vector of length $len , got $( length (flat)) " ) )
7979 fmap (x, off; exclude = isnumeric, walk, kw... ) do y, o
8080 _getat (y, o, flat)
8181 end
@@ -105,12 +105,14 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
105105 Tangent {typeof(x), typeof(y)} (y)
106106end
107107
108- function ChainRulesCore. rrule (:: typeof (_rebuild), x, off, flat; len)
109- dflat = map! (zero, similar (flat, float (eltype (flat))), flat)
110- _rebuild_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, dflat))
111- _rebuild (x, off, flat; len), _rebuild_back
108+ function ChainRulesCore. rrule (:: typeof (_rebuild), x, off, flat, len; kw... )
109+ _rebuild_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, _zero (flat)), NoT)
110+ _rebuild (x, off, flat, len; kw... ), _rebuild_back
112111end
113112
113+ _zero (x) = map! (zero, similar (x, float (eltype (x))), x) # mutable zero array for _grad!
114+ ChainRulesCore. @non_differentiable _zero (x)
115+
114116# This is the gradient of model reconstruction, accumulating duplicates:
115117function _grad! (x, dx, off, flat:: AbstractVector )
116118 x′, _ = functor (typeof (x), x)
0 commit comments