Skip to content

Commit 5a7bfc8

Browse files
committed
replace append! with reduce(vcat, ...)
1 parent 868903a commit 5a7bfc8

File tree

2 files changed

+16
-26
lines changed

2 files changed

+16
-26
lines changed

src/destructure.jl

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
33
const NoT = NoTangent()
44

55
"""
6-
destructure([T], model) -> vector, reconstructor
6+
destructure(model) -> vector, reconstructor
77
88
Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
9-
to a `Vector{T}`, and returns also a function which reverses this transformation.
9+
to a vector, and returns also a function which reverses this transformation.
1010
Differentiable.
1111
1212
# Example
@@ -18,8 +18,8 @@ julia> re([10,20,30])
1818
(x = [10.0, 20.0], y = (sin, [30.0]))
1919
```
2020
"""
21-
function destructure(::Type{T}, x) where T
22-
flat, off, len = alpha!(x, T[])
21+
function destructure(x)
22+
flat, off, len = alpha(x)
2323
flat, Restucture(x, off, len)
2424
end
2525

@@ -32,19 +32,22 @@ end
3232
Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
3333

3434
# This flattens a model, and returns a web of offsets for later use:
35-
function alpha!(x, flat::AbstractVector)
36-
isempty(flat) || error("this won't work")
37-
isnumeric(x) && return append!(flat, x), 0 # trivial case
35+
function alpha(x)
36+
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
37+
arrays = AbstractVector[]
38+
len = Ref(0)
3839
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
39-
append!(flat, y)
40-
length(flat) - length(y)
40+
push!(arrays, vec(y))
41+
o = len[]
42+
len[] = o + length(y)
43+
o
4144
end
42-
flat, off, length(flat)
45+
reduce(vcat, arrays), off, len[]
4346
end
4447

45-
function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
46-
flat, off, len = alpha!(x, flat)
47-
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
48+
function ChainRulesCore.rrule(::typeof(alpha), x)
49+
flat, off, len = alpha(x)
50+
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
4851
(flat, off, len), alpha_back
4952
end
5053

@@ -100,14 +103,3 @@ function gamma!(x, dx, off::Integer, flat::AbstractVector)
100103
end
101104
gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing
102105
gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
103-
104-
# Least importantly, this infers the eltype if one is not given:
105-
destructure(x) = destructure(omega(x), x)
106-
function omega(x)
107-
T = Bool
108-
fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
109-
T = promote_type(T, eltype(y))
110-
end
111-
T
112-
end
113-
ChainRulesCore.@non_differentiable omega(::Any)

test/destructure.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ m6 = (a = m1, b = [4.0 + im], c = m1)
88
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
99

1010
@testset "flatten & restore" begin
11-
@test destructure(Int, m1)[1] isa Vector{Int}
1211
@test destructure(m1)[1] isa Vector{Float64}
13-
1412
@test destructure(m1)[1] == 1:3
1513
@test destructure(m2)[1] == 1:6
1614
@test destructure(m3)[1] == 1:6

0 commit comments

Comments
 (0)