Skip to content

Commit 17b57f0

Browse files
committed
make len positional, fix a bug
1 parent 6f3eefa commit 17b57f0

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/destructure.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct Restructure{T,S}
4848
offsets::S
4949
length::Int
5050
end
51-
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat; len = re.length)
51+
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
5252
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
5353
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
5454
Base.length(re::Restructure) = re.length
@@ -69,13 +69,13 @@ end
6969

7070
function ChainRulesCore.rrule(::typeof(_flatten), x)
7171
flat, off, len = _flatten(x)
72-
_flatten_back((dflat, _)) = (NoT, _rebuild(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
72+
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
7373
(flat, off, len), _flatten_back
7474
end
7575

7676
# This reconstructs either a model like x, or a gradient for it:
77-
function _rebuild(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
78-
len == length(flat) || error("wrong length")
77+
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
78+
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
7979
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
8080
_getat(y, o, flat)
8181
end
@@ -105,12 +105,14 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
105105
Tangent{typeof(x), typeof(y)}(y)
106106
end
107107

108-
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
109-
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
110-
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat))
111-
_rebuild(x, off, flat; len), _rebuild_back
108+
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
109+
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
110+
_rebuild(x, off, flat, len; kw...), _rebuild_back
112111
end
113112

113+
_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad!
114+
ChainRulesCore.@non_differentiable _zero(x)
115+
114116
# This is the gradient of model reconstruction, accumulating duplicates:
115117
function _grad!(x, dx, off, flat::AbstractVector)
116118
x′, _ = functor(typeof(x), x)

0 commit comments

Comments
 (0)