Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.7.14"
version = "0.7.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
20 changes: 19 additions & 1 deletion src/abstractblocksparsearray/abstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value, ::Block{0})
return a
end

# Custom `_convert` works around the issue that
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
# https://github.com/JuliaLang/julia/pull/52487).
# TODO: Delete once we drop support for Julia v1.10.
_convert(::Type{T}, a::AbstractArray) where {T} = convert(T, a)
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
_construct(T::Type{<:Diagonal}, a::AbstractMatrix) = T(diag(a))
function _convert(T::Type{<:Diagonal}, a::AbstractMatrix)
LinearAlgebra.checksquare(a)
return isdiag(a) ? _construct(T, a) : throw(InexactError(:convert, T, a))
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
Expand All @@ -74,7 +87,12 @@ function Base.setindex!(
),
)
end
blocks(a)[Int.(I)...] = value
# Custom `_convert` works around the issue that
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
# https://github.com/JuliaLang/julia/pull/52487).
# TODO: Delete once we drop support for Julia v1.10.
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
return a
end

Expand Down
34 changes: 29 additions & 5 deletions src/abstractblocksparsearray/arraylayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,37 @@ function ArrayLayouts.MemoryLayout(
end

function Base.similar(
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
mul::MulAdd{
<:BlockLayout{<:SparseLayout,BlockLayoutA},
<:BlockLayout{<:SparseLayout,BlockLayoutB},
LayoutC,
T,
A,
B,
C,
},
elt::Type,
axes,
) where {A,B}
# TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type.
output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...})
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
) where {BlockLayoutA,BlockLayoutB,LayoutC,T,A,B,C}

# TODO: Consider using this instead:
# ```julia
# blockmultype = MulAdd{BlockLayoutA,BlockLayoutB,LayoutC,T,blocktype(A),blocktype(B),C}
# output_blocktype = Base.promote_op(
# similar, blockmultype, Type{elt}, Tuple{eltype.(eachblockaxis.(axes))...}
# )
# ```
# The issue is that it in some cases it seems to lose some information about the block types.

# TODO: Maybe this should be:
# output_blocktype = Base.promote_op(
# mul!, blocktype(mul.A), blocktype(mul.B), blocktype(mul.C), typeof(mul.α), typeof(mul.β)
# )

output_blocktype = Base.promote_op(*, blocktype(mul.A), blocktype(mul.B))
output_blocktype′ =
!isconcretetype(output_blocktype) ? AbstractMatrix{elt} : output_blocktype
return similar(BlockSparseArray{elt,length(axes),output_blocktype′}, axes)
end

# Materialize a SubArray view.
Expand Down
22 changes: 14 additions & 8 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,11 @@ end

function blocksparse_similar(a, elt::Type, axes::Tuple)
ndims = length(axes)
blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
return BlockSparseArray{elt,ndims,blockt}(undef, axes)
# TODO: Define a version of `similartype` that catches the case
# where the output isn't concrete and returns an `AbstractArray`.
blockt = Base.promote_op(similar, blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
blockt′ = !isconcretetype(blockt) ? AbstractArray{elt,ndims} : blockt
return BlockSparseArray{elt,ndims,blockt′}(undef, axes)
end
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}
Expand Down Expand Up @@ -283,13 +286,11 @@ function Base.similar(
elt::Type,
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

# Fixes ambiguity error.
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{})
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -301,7 +302,6 @@ function Base.similar(
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -311,7 +311,6 @@ function Base.similar(
elt::Type,
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -321,9 +320,17 @@ function Base.similar(
elt::Type,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type)
return @interface interface(a) similar(a, elt, axes(a))
end
function Base.similar(
a::AnyAbstractBlockSparseArray,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
return @interface interface(a) similar(a, eltype(a), axes)
end

# Fixes ambiguity errors with BlockArrays.
function Base.similar(
Expand All @@ -343,7 +350,6 @@ end
function Base.similar(
a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand Down
8 changes: 8 additions & 0 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ function eachblockstoredindex(a::AbstractArray)
return Block.(Tuple.(eachstoredindex(blocks(a))))
end

function SparseArraysBase.isstored(a::AbstractArray, I1::Block{1}, Irest::Block{1}...)
I = (I1, Irest...)
return isstored(blocks(a), Int.(I)...)
end
function SparseArraysBase.isstored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return isstored(a, Tuple(I)...)
end

using DiagonalArrays: diagindices
# Block version of `DiagonalArrays.diagindices`.
function blockdiagindices(a::AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions src/blocksparsearrayinterface/getunstoredblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ end
ax = ntuple(N) do d
return only(axes(f.axes[d][Block(I[d])]))
end
!isconcretetype(A) && return zero!(similar(Array{eltype(A),N}, ax))
return zero!(similar(A, ax))
end
@inline function (f::GetUnstoredBlock)(
Expand Down
55 changes: 38 additions & 17 deletions src/blocksparsearrayinterface/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,38 @@ function union_eachblockstoredindex(as::AbstractArray...)
return ∪(map(eachblockstoredindex, as)...)
end

# Get a view of a block assuming it is stored.
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
return blocks(a)[Int.(I)...]
end
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return viewblock_stored(a, Tuple(I)...)
end

using FillArrays: Zeros
# Get a view of a block if it is stored, otherwise return a lazy zeros.
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
if isstored(a, I...)
return viewblock_stored(a, I...)
else
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
return Zeros{eltype(a)}(block_ax)
end
end
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return viewblock_or_zeros(a, Tuple(I)...)
end

function map_block!(f, a_dest::AbstractArray, I::Block, a_srcs::AbstractArray...)
a_srcs_I = map(a_src -> viewblock_or_zeros(a_src, I), a_srcs)
if isstored(a_dest, I)
a_dest[I] .= f.(a_srcs_I...)
else
a_dest[I] = Broadcast.broadcast_preserving_zero_d(f, a_srcs_I...)
end
return a_dest
end

function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
# TODO: This assumes element types are numbers, generalize this logic.
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
Expand All @@ -27,22 +59,7 @@ function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
BlockRange(a_dest)
end
for I in Is
# TODO: Use:
# block_dest = @view a_dest[I]
# or:
# block_dest = @view! a_dest[I]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
# TODO: Use:
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
block_srcs = map(a_srcs) do a_src
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
end
# TODO: Use `map!!` to handle immutable blocks.
map!(f, block_dest, block_srcs...)
# Replace the entire block, handles initializing new blocks
# or if blocks are immutable.
# TODO: Use `a_dest[I] = block_dest`.
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
map_block!(f, a_dest, I, a_srcs...)
end
return a_dest
end
Expand Down Expand Up @@ -151,8 +168,12 @@ end
function map_stored_blocks(f, a::AbstractArray)
block_stored_indices = collect(eachblockstoredindex(a))
if isempty(block_stored_indices)
eltype_a′ = Base.promote_op(f, eltype(a))
blocktype_a′ = Base.promote_op(f, blocktype(a))
return BlockSparseArray{eltype(blocktype_a′),ndims(a),blocktype_a′}(undef, axes(a))
eltype_a′′ = !isconcretetype(eltype_a′) ? Any : eltype_a′
blocktype_a′′ =
!isconcretetype(blocktype_a′) ? AbstractArray{eltype_a′′,ndims(a)} : blocktype_a′
return BlockSparseArray{eltype_a′′,ndims(a),blocktype_a′′}(undef, axes(a))
end
stored_blocks = map(B -> f(@view!(a[B])), block_stored_indices)
blocktype_a′ = eltype(stored_blocks)
Expand Down
13 changes: 12 additions & 1 deletion src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,21 @@ function MatrixAlgebraKit.default_svd_algorithm(
return BlockPermutedDiagonalAlgorithm(alg)
end

function output_type(
::typeof(svd_compact!),
A::Type{<:AbstractMatrix{T}},
Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {T}
USVᴴ = Base.promote_op(svd_compact!, A, Alg)
!isconcretetype(USVᴴ) &&
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
return USVᴴ
end

function similar_output(
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
)
BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg)))
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg)))
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
S = similar(A, BlockType(BS), S_axes)
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand Down
44 changes: 44 additions & 0 deletions test/test_abstract_blocktype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Adapt: adapt
using BlockArrays: Block
using BlockSparseArrays: BlockSparseMatrix, blockstoredlength
using JLArrays: JLArray
using SparseArraysBase: storedlength
using Test: @test, @test_broken, @testset

elts = (Float32, Float64, ComplexF32)
arrayts = (Array, JLArray)
@testset "Abstract block type (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
elt in elts

dev = adapt(arrayt)
a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
@test sprint(show, MIME"text/plain"(), a) isa String
@test iszero(storedlength(a))
@test iszero(blockstoredlength(a))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test !iszero(a[Block(1, 1)])
@test a[Block(1, 1)] isa arrayt{elt,2}
@test !iszero(a[Block(2, 2)])
@test a[Block(2, 2)] isa arrayt{elt,2}
@test iszero(a[Block(2, 1)])
@test a[Block(2, 1)] isa Matrix{elt}
@test iszero(a[Block(1, 2)])
@test a[Block(1, 2)] isa Matrix{elt}

b = copy(a)
@test Array(b) ≈ Array(a)

b = a + a
@test Array(b) ≈ Array(a) + Array(a)

b = 3a
@test Array(b) ≈ 3Array(a)

if arrayt === Array
b = a * a
@test Array(b) ≈ Array(a) * Array(a)
else
@test_broken a * a
end
end
Loading
Loading