|
17 | 17 | blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)] |
18 | 18 | blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)] |
19 | 19 |
|
20 | | - @testset "$T" for (T, (b1, b2, b3)) in ( |
21 | | - Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))), |
22 | | - Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)), |
23 | | - ) |
24 | | - A = rand(rng, N, N + N1) |
25 | | - B = rand(rng, N + N1, N + N2) |
26 | | - A′, B′ = A', B' |
27 | | - a = rand(rng, N) |
28 | | - b = rand(rng, N + N1) |
| 20 | + @testset for V in (Tuple, Vector) |
| 21 | + b1 = BlockDiagonal(V(blocks1)) |
| 22 | + b2 = BlockDiagonal(V(blocks2)) |
| 23 | + N = size(b1, 1) |
29 | 24 |
|
30 | 25 | @testset "AbstractArray" begin |
31 | | - X = rand(2, 2); Y = rand(3, 3) |
| 26 | + X = rand(2, 2) |
| 27 | + Y = rand(3, 3) |
32 | 28 |
|
33 | 29 | @test size(b1) == (N, N) |
34 | 30 | @test size(b1, 1) == N && size(b1, 2) == N |
|
53 | 49 | end |
54 | 50 |
|
55 | 51 | @testset "parent" begin |
56 | | - @test parent(b1) isa Union{Tuple,AbstractVector} |
| 52 | + @test parent(b1) isa V |
57 | 53 | @test eltype(parent(b1)) <: AbstractMatrix |
58 | 54 | @test parent(BlockDiagonal([X, Y])) == [X, Y] |
59 | 55 | @test parent(BlockDiagonal((X, Y))) == (X, Y) |
|
66 | 62 | end |
67 | 63 |
|
68 | 64 | @testset "setindex!" begin |
69 | | - X = BlockDiagonal([rand(Float32, 5, 5), rand(Float32, 3, 3)]) |
| 65 | + X = BlockDiagonal(V([rand(Float32, 5, 5), rand(Float32, 3, 3)])) |
70 | 66 | X[10] = Int(10) |
71 | 67 | @test X[10] === Float32(10.0) |
72 | 68 | X[3, 3] = Int(9) |
|
78 | 74 |
|
79 | 75 | @testset "ChainRules" begin |
80 | 76 | @testset "BlockDiagonal" begin |
81 | | - x = [randn(1, 2), randn(2, 2)] |
82 | | - x̄ = [randn(1, 2), randn(2, 2)] |
83 | | - ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)]) |
| 77 | + x = V([randn(1, 2), randn(2, 2)]) |
| 78 | + x̄ = V([randn(1, 2), randn(2, 2)]) |
| 79 | + |
| 80 | + ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=V([randn(1, 2), randn(2, 2)])) |
84 | 81 | rrule_test(BlockDiagonal, ȳ, (x, x̄)) |
85 | 82 | end |
86 | 83 | @testset "Matrix" begin |
87 | | - D = BlockDiagonal([randn(1, 2), randn(2, 2)]) |
88 | | - D̄ = Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), ) |
| 84 | + D = BlockDiagonal(V([randn(1, 2), randn(2, 2)])) |
| 85 | + D̄ = Composite{typeof(D)}((blocks=V([randn(1, 2), randn(2, 2)])),) |
89 | 86 | Ȳ = randn(size(D)) |
90 | 87 | rrule_test(Matrix, Ȳ, (D, D̄)) |
91 | 88 | end |
|
98 | 95 | end |
99 | 96 |
|
100 | 97 | @testset "blocks size" begin |
101 | | - B = BlockDiagonal([rand(3, 3), rand(4, 4)]) |
| 98 | + B = BlockDiagonal(V([rand(3, 3), rand(4, 4)])) |
102 | 99 | @test nblocks(B) == 2 |
103 | | - @test blocksizes(B) == [(3, 3), (4, 4)] |
| 100 | + @test blocksizes(B) == V([(3, 3), (4, 4)]) |
104 | 101 | @test blocksize(B, 2) == blocksizes(B)[2] == blocksize(B, 2, 2) |
105 | 102 | end |
106 | 103 |
|
|
124 | 121 | @testset "Non-Square Matrix" begin |
125 | 122 | A1 = ones(2, 4) |
126 | 123 | A2 = 2 * ones(3, 2) |
127 | | - B1 = BlockDiagonal([A1, A2]) |
128 | | - B2 = [A1 zeros(2, 2); zeros(3, 4) A2] |
| 124 | + B1 = BlockDiagonal(V([A1, A2])) |
| 125 | + B2 = [A1 zeros(2, 2); zeros(3, 4) A2] |
129 | 126 |
|
130 | 127 | @test B1 == B2 |
131 | 128 | # Dimension check |
|
0 commit comments