@@ -2,64 +2,64 @@ module BlockSparseArraysTensorAlgebraExt
22
33using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
44using TensorAlgebra:
5- TensorAlgebra,
6- BlockedTrivialPermutation,
7- BlockedTuple,
8- FusionStyle,
9- ReshapeFusion,
10- fuseaxes
5+ TensorAlgebra,
6+ BlockedTrivialPermutation,
7+ BlockedTuple,
8+ FusionStyle,
9+ ReshapeFusion,
10+ fuseaxes
1111
1212struct BlockReshapeFusion <: FusionStyle end
1313
1414function TensorAlgebra. FusionStyle (:: Type{<:AbstractBlockSparseArray} )
15- return BlockReshapeFusion ()
15+ return BlockReshapeFusion ()
1616end
1717
1818using BlockArrays: Block, blocklength, blocks
1919using BlockSparseArrays: blocksparse
2020using SparseArraysBase: eachstoredindex
2121using TensorAlgebra: TensorAlgebra, matricize, unmatricize
2222function TensorAlgebra. matricize (
23- :: BlockReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
24- )
25- ax = fuseaxes (axes (a), biperm)
26- reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
27- key (I) = Block (Tuple (I))
28- value (I) = matricize (reshaped_blocks_a[I], biperm)
29- Is = eachstoredindex (reshaped_blocks_a)
30- bs = if isempty (Is)
31- # Catch empty case and make sure the type is constrained properly.
32- # This seems to only be necessary in Julia versions below v1.11,
33- # try removing it when we drop support for those versions.
34- keytype = Base. promote_op (key, eltype (Is))
35- valtype = Base. promote_op (value, eltype (Is))
36- valtype′ = ! isconcretetype (valtype) ? AbstractMatrix{eltype (a)} : valtype
37- Dict {keytype,valtype′} ()
38- else
39- Dict (key (I) => value (I) for I in Is)
40- end
41- return blocksparse (bs, ax)
23+ :: BlockReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
24+ )
25+ ax = fuseaxes (axes (a), biperm)
26+ reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
27+ key (I) = Block (Tuple (I))
28+ value (I) = matricize (reshaped_blocks_a[I], biperm)
29+ Is = eachstoredindex (reshaped_blocks_a)
30+ bs = if isempty (Is)
31+ # Catch empty case and make sure the type is constrained properly.
32+ # This seems to only be necessary in Julia versions below v1.11,
33+ # try removing it when we drop support for those versions.
34+ keytype = Base. promote_op (key, eltype (Is))
35+ valtype = Base. promote_op (value, eltype (Is))
36+ valtype′ = ! isconcretetype (valtype) ? AbstractMatrix{eltype (a)} : valtype
37+ Dict {keytype, valtype′} ()
38+ else
39+ Dict (key (I) => value (I) for I in Is)
40+ end
41+ return blocksparse (bs, ax)
4242end
4343
4444using BlockArrays: blocklengths
4545function TensorAlgebra. unmatricize (
46- :: BlockReshapeFusion ,
47- m:: AbstractMatrix ,
48- blocked_ax:: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} ,
49- )
50- ax = Tuple (blocked_ax)
51- reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
52- function f (I)
53- block_axes_I = BlockedTuple (
54- map (ntuple (identity, length (ax))) do i
55- return Base. axes1 (ax[i][Block (I[i])])
56- end ,
57- blocklengths (blocked_ax),
46+ :: BlockReshapeFusion ,
47+ m:: AbstractMatrix ,
48+ blocked_ax:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
5849 )
59- return unmatricize (reshaped_blocks_m[I], block_axes_I)
60- end
61- bs = Dict (Block (Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
62- return blocksparse (bs, ax)
50+ ax = Tuple (blocked_ax)
51+ reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
52+ function f (I)
53+ block_axes_I = BlockedTuple (
54+ map (ntuple (identity, length (ax))) do i
55+ return Base. axes1 (ax[i][Block (I[i])])
56+ end ,
57+ blocklengths (blocked_ax),
58+ )
59+ return unmatricize (reshaped_blocks_m[I], block_axes_I)
60+ end
61+ bs = Dict (Block (Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
62+ return blocksparse (bs, ax)
6363end
6464
6565end
0 commit comments