Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.7.3"
version = "0.7.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
99 changes: 99 additions & 0 deletions src/abstractblocksparsearray/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,102 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
end
return tr_a
end

# TODO: Define `SparseArraysBase.isdiag`, define as
# `isdiag(blocks(a))`.
function blockisdiag(a::AbstractArray)
return all(eachblockstoredindex(a)) do I
return allequal(Tuple(I))
end
end

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

# Functions where the dense implementations in `LinearAlgebra` are
# not type stable.
const MATRIX_FUNCTIONS_UNSTABLE = [
:log,
:sqrt,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

function initialize_output_blocksparse(f::F, a::AbstractMatrix) where {F}
B = Base.promote_op(f, blocktype(a))
return similar(a, BlockType(B))
end

function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F}
blockisdiag(a) || throw(ArgumentError("`$f` only defined for block-diagonal matrices"))
fa = initialize_output_blocksparse(f, a)
for I in blockdiagindices(a)
fa[I] = f(a[I]; kwargs...)
end
return fa
end

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::AnyAbstractBlockSparseMatrix)
return matrix_function_blocksparse($f, a)
end
end
end

for f in MATRIX_FUNCTIONS_UNSTABLE
@eval begin
function initialize_output_blocksparse(::typeof($f), a::AbstractMatrix)
B = similartype(blocktype(a), complex(eltype(a)))
return similar(a, BlockType(B))
end
end
end

function LinearAlgebra.inv(a::AnyAbstractBlockSparseMatrix)
return matrix_function_blocksparse(inv, a)
end

using LinearAlgebra: LinearAlgebra, pinv
function LinearAlgebra.pinv(a::AnyAbstractBlockSparseMatrix; kwargs...)
return matrix_function_blocksparse(pinv, a; kwargs...)
end
89 changes: 86 additions & 3 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
using BlockSparseArrays:
BlockSparseArray, BlockDiagonal, blockstoredlength, eachblockstoredindex
BlockSparseArrays,
BlockDiagonal,
BlockSparseArray,
BlockSparseMatrix,
blockstoredlength,
eachblockstoredindex
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart, pinv
using MatrixAlgebraKit:
diagview,
eig_full,
Expand All @@ -22,10 +28,87 @@ using MatrixAlgebraKit:
svd_trunc,
truncrank,
trunctol
using LinearAlgebra: LinearAlgebra, Diagonal, hermitianpart
using Random: Random
using StableRNGs: StableRNG
using Test: @inferred, @testset, @test
using Test: @inferred, @test, @test_broken, @test_throws, @testset

@testset "Matrix functions (T=$elt)" for elt in (Float32, Float64, ComplexF64)
rng = StableRNG(123)
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
a[Block(1, 1)] = randn(rng, elt, 2, 2)
a[Block(2, 2)] = randn(rng, elt, 3, 3)
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
# Only works when real, also isn't defined in Julia 1.10.
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
MATRIX_FUNCTIONS_LOW_ACCURACY = [:acoth]
for f in setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_LOW_ACCURACY)
@eval begin
fa = $f($a)
@test Matrix(fa) ≈ $f(Matrix($a)) rtol = √(eps(real($elt)))
@test fa isa BlockSparseMatrix
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
end
end
for f in MATRIX_FUNCTIONS_LOW_ACCURACY
@eval begin
fa = $f($a)
if !Sys.isapple() && isreal($elt)
# `acoth` appears to be broken on this matrix on Windows and Ubuntu
# for real matrices.
@test_broken Matrix(fa) ≈ $f(Matrix($a)) rtol = √eps(real($elt))
else
@test Matrix(fa) ≈ $f(Matrix($a)) rtol = √eps(real($elt))
end
@test fa isa BlockSparseMatrix
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
end
end

# Catch case of off-diagonal blocks.
rng = StableRNG(123)
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
a[Block(1, 1)] = randn(rng, elt, 2, 2)
a[Block(1, 2)] = randn(rng, elt, 2, 3)
for f in MATRIX_FUNCTIONS
@eval begin
@test_throws ArgumentError $f($a)
end
end

# Missing diagonal blocks.
rng = StableRNG(123)
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
a[Block(2, 2)] = randn(rng, elt, 3, 3)
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
# These functions involve inverses so they break when there are zeros on the diagonal.
MATRIX_FUNCTIONS_SINGULAR = [
:log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth
]
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
# Dense version is broken for some reason, investigate.
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
for f in MATRIX_FUNCTIONS
@eval begin
fa = $f($a)
@test Matrix(fa) ≈ $f(Matrix($a)) rtol = √(eps(real($elt)))
@test fa isa BlockSparseMatrix
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
end
end

SINGULAR_EXCEPTION = if VERSION < v"1.11-"
# A different exception is thrown in older versions of Julia.
LinearAlgebra.LAPACKException
else
LinearAlgebra.SingularException
end
for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log])
@eval begin
@test_throws $SINGULAR_EXCEPTION $f($a)
end
end
end

function test_svd(a, (U, S, Vᴴ); full=false)
# Check that the SVD is correct
Expand Down
Loading