@@ -15,19 +15,51 @@ function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
1515 return BlockReshapeFusion ()
1616end
1717
18+ using BlockArrays: Block, blocklength, blocks
19+ using BlockSparseArrays: blocksparse
20+ using SparseArraysBase: eachstoredindex
21+ using TensorAlgebra: TensorAlgebra, matricize, unmatricize
1822function TensorAlgebra. matricize (
1923 :: BlockReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
2024)
21- new_axes = fuseaxes (axes (a), biperm)
22- return blockreshape (a, new_axes)
25+ ax = fuseaxes (axes (a), biperm)
26+ reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
27+ key (I) = Block (Tuple (I))
28+ value (I) = matricize (reshaped_blocks_a[I], biperm)
29+ Is = eachstoredindex (reshaped_blocks_a)
30+ bs = if isempty (Is)
31+ # Catch empty case and make sure the type is constrained properly.
32+ # This seems to only be necessary in Julia versions below v1.11,
33+ # try removing it when we drop support for those versions.
34+ keytype = Base. promote_op (key, eltype (Is))
35+ valtype = Base. promote_op (value, eltype (Is))
36+ valtype′ = ! isconcretetype (valtype) ? AbstractMatrix{eltype (a)} : valtype
37+ Dict {keytype,valtype′} ()
38+ else
39+ Dict (key (I) => value (I) for I in Is)
40+ end
41+ return blocksparse (bs, ax)
2342end
2443
44+ using BlockArrays: blocklengths
2545function TensorAlgebra. unmatricize (
2646 :: BlockReshapeFusion ,
2747 m:: AbstractMatrix ,
28- blocked_axes :: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} ,
48+ blocked_ax :: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} ,
2949)
30- return blockreshape (m, Tuple (blocked_axes)... )
50+ ax = Tuple (blocked_ax)
51+ reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
52+ function f (I)
53+ block_axes_I = BlockedTuple (
54+ map (ntuple (identity, length (ax))) do i
55+ return Base. axes1 (ax[i][Block (I[i])])
56+ end ,
57+ blocklengths (blocked_ax),
58+ )
59+ return unmatricize (reshaped_blocks_m[I], block_axes_I)
60+ end
61+ bs = Dict (Block (Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
62+ return blocksparse (bs, ax)
3163end
3264
3365end
0 commit comments