diff --git a/Project.toml b/Project.toml index 6d636b7..0459c94 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" -version = "0.10.11" +version = "0.10.12" authors = ["ITensor developers and contributors"] [deps] @@ -44,7 +44,7 @@ MapBroadcast = "0.1.5" MatrixAlgebraKit = "0.6" SparseArraysBase = "0.7.1" SplitApplyCombine = "1.2.3" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.5" TensorProducts = "0.1.7" Test = "1.10" TypeParameterAccessors = "0.4.1" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 8a224f3..5efcc89 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,13 +1,7 @@ module BlockSparseArraysTensorAlgebraExt using BlockSparseArrays: AbstractBlockSparseArray, blockreshape -using TensorAlgebra: - TensorAlgebra, - BlockedTrivialPermutation, - BlockedTuple, - FusionStyle, - ReshapeFusion, - fuseaxes +using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes struct BlockReshapeFusion <: FusionStyle end @@ -20,12 +14,12 @@ using BlockSparseArrays: blocksparse using SparseArraysBase: eachstoredindex using TensorAlgebra: TensorAlgebra, matricize, unmatricize function TensorAlgebra.matricize( - ::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} + ::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val ) - ax = fuseaxes(axes(a), biperm) + ax = fuseaxes(axes(a), length1, length2) reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax)) key(I) = Block(Tuple(I)) - value(I) = matricize(reshaped_blocks_a[I], biperm) + value(I) = matricize(reshaped_blocks_a[I], length1, length2) Is = eachstoredindex(reshaped_blocks_a) bs = if isempty(Is) # Catch empty case and make sure the type is constrained properly. @@ -45,16 +39,17 @@ using BlockArrays: blocklengths function TensorAlgebra.unmatricize( ::BlockReshapeFusion, m::AbstractMatrix, - blocked_ax::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) - ax = Tuple(blocked_ax) + ax = (codomain_axes..., domain_axes...) reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax)) function f(I) block_axes_I = BlockedTuple( map(ntuple(identity, length(ax))) do i return Base.axes1(ax[i][Block(I[i])]) end, - blocklengths(blocked_ax), + (length(codomain_axes), length(domain_axes)), ) return unmatricize(reshaped_blocks_m[I], block_axes_I) end diff --git a/test/Project.toml b/test/Project.toml index 65a615d..d3284cc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,6 +20,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[sources] +BlockSparseArrays = {path = ".."} + [compat] Adapt = "4" Aqua = "0.8" @@ -37,7 +40,7 @@ SafeTestsets = "0.1" SparseArraysBase = "0.7" StableRNGs = "1" Suppressor = "0.2" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.5" Test = "1" TestExtras = "0.3" TypeParameterAccessors = "0.4"