22m1 = collect (1 : 3.0 )
33m2 = (collect (1 : 3.0 ), collect (4 : 6.0 ))
44m3 = (x = m1, y = sin, z = collect (4 : 6.0 ))
5+
56m4 = (x = m1, y = m1, z = collect (4 : 6.0 )) # tied
67m5 = (a = (m3, true ), b = (m1, false ), c = (m4, true ))
78m6 = (a = m1, b = [4.0 + im], c = m1)
9+
810m7 = TwoThirds ((sin, collect (1 : 3.0 )), (cos, collect (4 : 6.0 )), (tan, collect (7 : 9.0 )))
911m8 = [Foo (m1, m1), (a = true , b = Foo ([4.0 ], false ), c = ()), [[5.0 ]]]
1012
13+ mat = Float32[4 6 ; 5 7 ]
14+ m9 = (a = m1, b = mat, c = [mat, m1])
15+
1116@testset " flatten & rebuild" begin
1217 @test destructure (m1)[1 ] isa Vector{Float64}
1318 @test destructure (m1)[1 ] == 1 : 3
@@ -16,6 +21,7 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
1621 @test destructure (m4)[1 ] == 1 : 6
1722 @test destructure (m5)[1 ] == vcat (1 : 6 , 4 : 6 )
1823 @test destructure (m6)[1 ] == vcat (1 : 3 , 4 + im)
24+ @test destructure (m9)[1 ] == 1 : 7
1925
2026 @test destructure (m1)[2 ](7 : 9 ) == [7 ,8 ,9 ]
2127 @test destructure (m2)[2 ](4 : 9 ) == ([4 ,5 ,6 ], [7 ,8 ,9 ])
@@ -45,6 +51,10 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
4551 @test m8′[2 ]. b. y === false
4652 @test m8′[3 ][1 ] == [5.0 ]
4753
54+ m9′ = destructure (m9)[2 ](10 : 10 : 70 )
55+ @test m9′. b === m9′. c[1 ]
56+ @test m9′. b isa Matrix{Float32}
57+
4858 # errors
4959 @test_throws Exception destructure (m7)[2 ]([10 ,20 ])
5060 @test_throws Exception destructure (m7)[2 ]([10 ,20 ,30 ,40 ])
7181 @test g8[2 ]. b. x == [8 ]
7282 @test g8[3 ] == [[10.0 ]]
7383
84+ g9 = gradient (m -> sum (sqrt, destructure (m)[1 ]), m9)[1 ]
85+ @test g9. c === nothing
86+
7487 @testset " second derivative" begin
7588 @test gradient ([1 ,2 ,3.0 ]) do v
7689 sum (abs2, gradient (m -> sum (abs2, destructure (m)[1 ]), (v, [4 ,5 ,6.0 ]))[1 ][1 ])
119132 @test gradient (x -> sum (abs2, re8 (x)[1 ]. y), v8)[1 ] == [2 ,4 ,6 ,0 ,0 ]
120133 @test gradient (x -> only (sum (re8 (x)[3 ]))^ 2 , v8)[1 ] == [0 ,0 ,0 ,0 ,10 ]
121134
135+ re9 = destructure (m9)[2 ]
136+ @test gradient (x -> sum (abs2, re9 (x). c[1 ]), 1 : 7 )[1 ] == [0 ,0 ,0 , 8 ,10 ,12 ,14 ]
137+
122138 @testset " second derivative" begin
123139 @test_broken gradient (collect (1 : 6.0 )) do y
124140 sum (abs2, gradient (x -> sum (abs2, re2 (x)[1 ]), y)[1 ])
0 commit comments