|
1 | | -using Adapt: Adapt, WrappedArray |
| 1 | +using Adapt: Adapt, WrappedArray, adapt |
2 | 2 | using ArrayLayouts: zero! |
3 | 3 | using BlockArrays: |
4 | 4 | BlockArrays, |
@@ -337,60 +337,29 @@ function Base.Array(a::AnyAbstractBlockSparseArray) |
337 | 337 | return Array{eltype(a)}(a) |
338 | 338 | end |
339 | 339 |
|
340 | | -using SparseArraysBase: ReplacedUnstoredSparseArray |
341 | | - |
342 | | -# Wraps a block sparse array but replaces the unstored values. |
343 | | -# This is used in printing in order to customize printing |
344 | | -# of zero/unstored values. |
345 | | -struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <: |
346 | | - AbstractBlockSparseArray{T,N} |
347 | | - parent::Parent |
348 | | - getunstoredblock::F |
349 | | -end |
350 | | -Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent |
351 | | -Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a)) |
352 | | -function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray) |
353 | | - return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock) |
354 | | -end |
355 | | - |
356 | | -# This is copied from `SparseArraysBase.jl` since it is not part |
357 | | -# of the public interface. |
358 | | -# Like `Char` but prints without quotes. |
359 | | -struct UnquotedChar <: AbstractChar |
360 | | - char::Char |
| 340 | +function SparseArraysBase.isstored( |
| 341 | + A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N} |
| 342 | +) where {N} |
| 343 | + bI = BlockIndex(findblockindex.(axes(A), I)) |
| 344 | + bA = blocks(A) |
| 345 | + return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...) |
361 | 346 | end |
362 | | -Base.show(io::IO, c::UnquotedChar) = print(io, c.char) |
363 | | -Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c) |
364 | 347 |
|
365 | | -using FillArrays: Fill |
366 | | -struct GetUnstoredBlockShow{Axes} |
367 | | - axes::Axes |
368 | | -end |
369 | | -@inline function (f::GetUnstoredBlockShow)( |
370 | | - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} |
371 | | -) where {N} |
372 | | - # TODO: Make sure this works for sparse or block sparse blocks, immutable |
373 | | - # blocks, diagonal blocks, etc.! |
374 | | - b_size = ntuple(ndims(a)) do d |
375 | | - return length(f.axes[d][Block(I[d])]) |
| 348 | +function Base.replace_in_print_matrix( |
| 349 | + A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString |
| 350 | +) |
| 351 | + return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s) |
| 352 | +end |
| 353 | + |
| 354 | +# attempt to catch things that wrap GPU arrays |
| 355 | +function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray) |
| 356 | + X_cpu = adapt(Array, X) |
| 357 | + if typeof(X_cpu) === typeof(X) # prevent infinite recursion |
| 358 | + # need to specify ndims to allow specialized code for vector/matrix |
| 359 | + @allowscalar @invoke Base.print_array( |
| 360 | + io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)} |
| 361 | + ) |
| 362 | + else |
| 363 | + Base.print_array(io, X_cpu) |
376 | 364 | end |
377 | | - return Fill(UnquotedChar('.'), b_size) |
378 | | -end |
379 | | -# TODO: Use `Base.to_indices`. |
380 | | -@inline function (f::GetUnstoredBlockShow)( |
381 | | - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} |
382 | | -) where {N} |
383 | | - return f(a, Tuple(I)...) |
384 | | -end |
385 | | - |
386 | | -# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function |
387 | | -# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`. |
388 | | -function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray) |
389 | | - summary(io, a) |
390 | | - isempty(a) && return nothing |
391 | | - print(io, ":") |
392 | | - println(io) |
393 | | - a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a))) |
394 | | - @allowscalar Base.print_array(io, a′) |
395 | | - return nothing |
396 | 365 | end |
0 commit comments