@@ -19,10 +19,12 @@ using SymmetrySectors: U1
1919using TensorAlgebra: fusedims, splitdims
2020using LinearAlgebra: adjoint
2121using Random: randn!
22- function blockdiagonal! (f, a:: AbstractArray )
23- for i in 1 : minimum (blocksize (a))
22+ function randn_blockdiagonal (elt:: Type , axes:: Tuple )
23+ a = BlockSparseArray {elt} (axes)
24+ blockdiaglength = minimum (blocksize (a))
25+ for i in 1 : blockdiaglength
2426 b = Block (ntuple (Returns (i), ndims (a)))
25- a[b] = f (a[b])
27+ a[b] = randn! (a[b])
2628 end
2729 return a
2830end
@@ -32,8 +34,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3234 @testset " map" begin
3335 d1 = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 ])
3436 d2 = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 ])
35- a = BlockSparseArray {elt} (d1, d2, d1, d2)
36- blockdiagonal! (randn!, a)
37+ a = randn_blockdiagonal (elt, (d1, d2, d1, d2))
3738 @test axes (a, 1 ) isa GradedOneTo
3839 @test axes (view (a, 1 : 4 , 1 : 4 , 1 : 4 , 1 : 4 ), 1 ) isa GradedOneTo
3940
@@ -89,8 +90,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8990 @testset " fusedims" begin
9091 d1 = gradedrange ([U1 (0 ) => 1 , U1 (1 ) => 1 ])
9192 d2 = gradedrange ([U1 (0 ) => 1 , U1 (1 ) => 1 ])
92- a = BlockSparseArray {elt} (d1, d2, d1, d2)
93- blockdiagonal! (randn!, a)
93+ a = randn_blockdiagonal (elt, (d1, d2, d1, d2))
9494 m = fusedims (a, (1 , 2 ), (3 , 4 ))
9595 for ax in axes (m)
9696 @test ax isa GradedOneTo
@@ -107,6 +107,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
107107 @test a[2 , 2 , 2 , 2 ] == m[4 , 4 ]
108108 @test blocksize (m) == (3 , 3 )
109109 @test a == splitdims (m, (d1, d2), (d1, d2))
110+
111+ # check block fusing and splitting
112+ d = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 1 ])
113+ a = randn_blockdiagonal (elt, (d, d, dual (d), dual (d)))
114+ @test splitdims (fusedims (a, (1 , 2 ), (3 , 4 )), axes (a)... ) == a
110115 end
111116
112117 @testset " dual axes" begin
0 commit comments