Skip to content

Commit 756b450

Browse files
committed
arrays of arrays
1 parent 337f365 commit 756b450

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

src/destructure.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
102102
au, _ = functor(typeof(x), aux)
103103
y = _trainmap(f, ch, _trainable(x), au)
104104
y isa Tuple{} && return NoT
105-
Tangent{typeof(x), typeof(y)}(y)
105+
p = ProjectTo(x)
106+
if p isa ProjectTo # e.g. Array, NamedTuple
107+
p(y)
108+
else # p === identity for unknown structs
109+
Tangent{typeof(x), typeof(y)}(y)
110+
end
106111
end
107112

108113
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ trainable(x) = functor(x)[1]
7171
# Single-argument entry point: fetch the children reported by `functor(x)` and
# hand them, together with `trainable(x)`, to the two-argument `_trainable` methods.
function _trainable(x)
    children = functor(x)[1]
    return _trainable(children, trainable(x))
end
7272
# NamedTuple children with NamedTuple trainables: start from a copy of `ch` in
# which every value is `nothing`, then let the entries of `tr` overwrite it.
# Keys of `ch` that are absent from `tr` therefore stay `nothing`.
function _trainable(ch::NamedTuple, tr::NamedTuple)
    blanked = map(_ -> nothing, ch)
    return merge(blanked, tr)
end
7373
# Tuple children with a Tuple of trainables of the same length `N`:
# the trainables are returned unchanged.
function _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where {N}
    return tr
end
74+
# Array children with an array of trainables: the trainables are returned
# unchanged (the same object, not a copy).
function _trainable(ch::AbstractArray, tr::AbstractArray)
    return tr
end
7475
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
7576
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple"
7677
map(c -> c in tr ? c : nothing, ch)

test/destructure.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ m4 = (x = m1, y = m1, z = collect(4:6.0))
66
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
77
m6 = (a = m1, b = [4.0 + im], c = m1)
88
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
9+
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
910

1011
@testset "flatten & rebuild" begin
1112
@test destructure(m1)[1] isa Vector{Float64}
@@ -31,12 +32,20 @@ m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0
3132
@test m6′.a === m6′.c
3233
@test m6′.b == [7 + 4im]
3334

35+
# struct, trainable
3436
@test destructure(m7)[1] == 1:3
3537
m7′ = destructure(m7)[2]([10,20,30])
3638
@test m7′.a == (sin, [10,20,30])
3739
@test m7′.b == (cos, [4,5,6])
3840
@test m7′.c == (tan, [7,8,9])
3941

42+
@test destructure(m8)[1] == 1:5
43+
m8′ = destructure(m8)[2](1:5)
44+
@test m8′[1].x === m8′[1].y
45+
@test m8′[2].b.y === false
46+
@test m8′[3][1] == [5.0]
47+
48+
# errors
4049
@test_throws Exception destructure(m7)[2]([10,20])
4150
@test_throws Exception destructure(m7)[2]([10,20,30,40])
4251
end
@@ -57,6 +66,11 @@ end
5766
@test g6.a isa Vector{Float64}
5867
@test g6.b == [0+im]
5968

69+
g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
70+
@test g8[1].x == [2,4,6]
71+
@test g8[2].b.x == [8]
72+
@test g8[3] == [[10.0]]
73+
6074
@testset "second derivative" begin
6175
@test_broken gradient([1,2,3.0]) do v
6276
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
@@ -90,6 +104,10 @@ end
90104
@test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
91105
@test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
92106

107+
v8, re8 = destructure(m8)
108+
@test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
109+
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
110+
93111
@testset "second derivative" begin
94112
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
95113
@test_broken gradient(collect(1:6.0)) do y

0 commit comments

Comments
 (0)