Skip to content

Commit 8a4e0c7

Browse files
authored
Fix broadcasting over sector and non-sector arrays (#66)
1 parent 3a057e4 commit 8a4e0c7

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.4.20"
4+
version = "0.4.21"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -26,7 +26,7 @@ GradedArraysTensorAlgebraExt = "TensorAlgebra"
2626
[compat]
2727
ArrayLayouts = "1"
2828
BlockArrays = "1.6"
29-
BlockSparseArrays = "0.8, 0.9.3"
29+
BlockSparseArrays = "0.9.5"
3030
Compat = "4.16"
3131
FillArrays = "1.13"
3232
HalfIntegers = "1.6"

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ function TensorAlgebra.unmatricize(
6262
m::AbstractMatrix,
6363
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
6464
)
65+
if isempty(blocked_axes)
66+
# Handle edge case of empty blocked_axes, which can occur
67+
# when matricizing a 0-dimensional array (a scalar).
68+
a = similar(m, ())
69+
a[] = only(m)
70+
return a
71+
end
72+
6573
# First, fuse axes to get `sectormergesortperm`.
6674
# Then unpermute the blocks.
6775
fused_axes = matricize_axes(blocked_axes)

src/sectorunitrange.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,10 @@ end
159159
function Base.reshape(A::AbstractArray, ax::Tuple{SectorOneTo,Vararg{SectorOneTo}})
160160
return reshape(A, ungrade.(ax))
161161
end
162+
163+
# Fixes issues when broadcasting over mixtures of arrays
164+
# where some have SectorOneTo axes and some have OneTo axes,
165+
# which can show up in BlockSparseArrays blockwise broadcasting.
166+
# See https://github.com/ITensor/GradedArrays.jl/pull/65.
167+
Base.Broadcast.axistype(r1::SectorOneTo, ::Base.OneTo) = r1
168+
Base.Broadcast.axistype(::Base.OneTo, r2::SectorOneTo) = r2

test/test_sectorunitrange.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ using TestExtras: @constinferred
4646
@test length(sr) == 6
4747
@test firstindex(sr) == 1
4848
@test lastindex(sr) == 6
49-
@test eltype(sr) === Int
49+
@test eltype(sr) Int
5050
@test step(sr) == 1
5151
@test eachindex(sr) == Base.oneto(6)
5252
@test only(axes(sr)) isa SectorOneTo
@@ -59,9 +59,9 @@ using TestExtras: @constinferred
5959
@test isnothing(iterate(sr, 6))
6060

6161
# Base.Slice
62-
@test axes(Base.Slice(sr)) === (sr,)
63-
@test Base.axes1(Base.Slice(sr)) === sr
64-
@test Base.unsafe_indices(Base.Slice(sr)) === (sr,)
62+
@test axes(Base.Slice(sr)) (sr,)
63+
@test Base.axes1(Base.Slice(sr)) sr
64+
@test Base.unsafe_indices(Base.Slice(sr)) (sr,)
6565

6666
@test sr == 1:6
6767
@test sr == sr
@@ -113,8 +113,8 @@ using TestExtras: @constinferred
113113
@test blockisequal(sr, sr)
114114

115115
# GradedUnitRanges interface
116-
@test sector_type(sr) === SU{3,2}
117-
@test sector_type(typeof(sr)) === SU{3,2}
116+
@test sector_type(sr) SU{3,2}
117+
@test sector_type(typeof(sr)) SU{3,2}
118118
@test sectors(sr) == [SU((1, 0))]
119119
@test sector_multiplicity(sr) == 2
120120
@test sector_multiplicities(sr) == [2]
@@ -138,7 +138,7 @@ using TestExtras: @constinferred
138138
end
139139
@test sr[2:3] == 2:3
140140
@test (@constinferred getindex(sr, 2:3)) isa UnitRange
141-
@test sr[Block(1)] === sr
141+
@test sr[Block(1)] sr
142142
@test_throws BlockBoundsError sr[Block(2)]
143143

144144
sr2 = (@constinferred getindex(sr, (:, 2)))
@@ -162,5 +162,10 @@ using TestExtras: @constinferred
162162
# Slice sector range with sector range
163163
sr1 = sectorrange(U1(1), 4)
164164
sr2 = sectorrange(U1(1), 3)
165-
@test sr1[sr2] === sr2
165+
@test sr1[sr2] sr2
166+
167+
sr = sectorrange(U1(1), 4)
168+
@test Broadcast.axistype(sr, sr) sr
169+
@test Broadcast.axistype(sr, Base.OneTo(4)) sr
170+
@test Broadcast.axistype(Base.OneTo(4), sr) sr
166171
end

0 commit comments

Comments
 (0)