Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.2.25"
version = "0.2.26"


[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
77 changes: 23 additions & 54 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Adapt: Adapt, WrappedArray
using Adapt: Adapt, WrappedArray, adapt
using ArrayLayouts: zero!
using BlockArrays:
BlockArrays,
Expand Down Expand Up @@ -337,60 +337,29 @@
return Array{eltype(a)}(a)
end

using SparseArraysBase: ReplacedUnstoredSparseArray

# Wraps a block sparse array but replaces the unstored values.
# This is used in printing in order to customize printing
# of zero/unstored values.
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
AbstractBlockSparseArray{T,N}
parent::Parent
getunstoredblock::F
end
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)
end

# This is copied from `SparseArraysBase.jl` since it is not part
# of the public interface.
# Like `Char` but prints without quotes.
struct UnquotedChar <: AbstractChar
char::Char
function SparseArraysBase.isstored(

Check warning on line 340 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L340

Added line #L340 was not covered by tests
A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
bI = BlockIndex(findblockindex.(axes(A), I))
bA = blocks(A)
return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...)

Check warning on line 345 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L343-L345

Added lines #L343 - L345 were not covered by tests
end
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)

using FillArrays: Fill
struct GetUnstoredBlockShow{Axes}
axes::Axes
end
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
# TODO: Make sure this works for sparse or block sparse blocks, immutable
# blocks, diagonal blocks, etc.!
b_size = ntuple(ndims(a)) do d
return length(f.axes[d][Block(I[d])])
function Base.replace_in_print_matrix(

Check warning on line 348 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L348

Added line #L348 was not covered by tests
A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString
)
return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s)

Check warning on line 351 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L351

Added line #L351 was not covered by tests
end

# attempt to catch things that wrap GPU arrays
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
X_cpu = adapt(Array, X)
if typeof(X_cpu) === typeof(X) # prevent infinite recursion

Check warning on line 357 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L355-L357

Added lines #L355 - L357 were not covered by tests
# need to specify ndims to allow specialized code for vector/matrix
@allowscalar @invoke Base.print_array(

Check warning on line 359 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L359

Added line #L359 was not covered by tests
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
)
else
Base.print_array(io, X_cpu)

Check warning on line 363 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L363

Added line #L363 was not covered by tests
end
return Fill(UnquotedChar('.'), b_size)
end
# TODO: Use `Base.to_indices`.
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return f(a, Tuple(I)...)
end

# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
summary(io, a)
isempty(a) && return nothing
print(io, ":")
println(io)
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
@allowscalar Base.print_array(io, a′)
return nothing
end
2 changes: 1 addition & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ arrayts = (Array, JLArray)
a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2])
@allowscalar a[1, 2] = 12
@test sprint(show, "text/plain", a) ==
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ . .\n $(zero(eltype(a))) $(zero(eltype(a))) │ . .\n ───────────┼──────\n . .. .\n . .. ."
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ \n ⋅ ⋅ "
end
end
@testset "TypeParameterAccessors.position" begin
Expand Down
Loading