@@ -6,6 +6,7 @@ m4 = (x = m1, y = m1, z = collect(4:6.0))
66m5 = (a = (m3, true ), b = (m1, false ), c = (m4, true ))
77m6 = (a = m1, b = [4.0 + im], c = m1)
88m7 = TwoThirds ((sin, collect (1 : 3.0 )), (cos, collect (4 : 6.0 )), (tan, collect (7 : 9.0 )))
9+ m8 = [Foo (m1, m1), (a = true , b = Foo ([4.0 ], false ), c = ()), [[5.0 ]]]
910
1011@testset " flatten & rebuild" begin
1112 @test destructure (m1)[1 ] isa Vector{Float64}
@@ -31,12 +32,20 @@ m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0
3132 @test m6′. a === m6′. c
3233 @test m6′. b == [7 + 4im ]
3334
35+ # struct, trainable
3436 @test destructure (m7)[1 ] == 1 : 3
3537 m7′ = destructure (m7)[2 ]([10 ,20 ,30 ])
3638 @test m7′. a == (sin, [10 ,20 ,30 ])
3739 @test m7′. b == (cos, [4 ,5 ,6 ])
3840 @test m7′. c == (tan, [7 ,8 ,9 ])
3941
42+ @test destructure (m8)[1 ] == 1 : 5
43+ m8′ = destructure (m8)[2 ](1 : 5 )
44+ @test m8′[1 ]. x === m8′[1 ]. y
45+ @test m8′[2 ]. b. y === false
46+ @test m8′[3 ][1 ] == [5.0 ]
47+
48+ # errors
4049 @test_throws Exception destructure (m7)[2 ]([10 ,20 ])
4150 @test_throws Exception destructure (m7)[2 ]([10 ,20 ,30 ,40 ])
4251end
5766 @test g6. a isa Vector{Float64}
5867 @test g6. b == [0 + im]
5968
69+ g8 = gradient (m -> sum (abs2, destructure (m)[1 ]), m8)[1 ]
70+ @test g8[1 ]. x == [2 ,4 ,6 ]
71+ @test g8[2 ]. b. x == [8 ]
72+ @test g8[3 ] == [[10.0 ]]
73+
6074 @testset " second derivative" begin
6175 @test_broken gradient ([1 ,2 ,3.0 ]) do v
6276 sum (abs2, gradient (m -> sum (abs2, destructure (m)[1 ]), (v, [4 ,5 ,6.0 ]))[1 ][1 ])
90104 @test gradient (x -> re7 (x). b[2 ][2 ], rand (3 ))[1 ] == [0 ,0 ,0 ]
91105 @test gradient (x -> re7 (x). c[2 ][1 ], rand (3 ))[1 ] == [0 ,0 ,0 ]
92106
107+ v8, re8 = destructure (m8)
108+ @test gradient (x -> sum (abs2, re8 (x)[1 ]. y), v8)[1 ] == [2 ,4 ,6 ,0 ,0 ]
109+ @test gradient (x -> only (sum (re8 (x)[3 ]))^ 2 , v8)[1 ] == [0 ,0 ,0 ,0 ,10 ]
110+
93111 @testset " second derivative" begin
94112 # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
95113 @test_broken gradient (collect (1 : 6.0 )) do y
0 commit comments