Skip to content

Commit 98723df

Browse files
authored
Structured broadcasting for UpperHessenberg (#1325)
This adds a custom structured broadcasting style for an `UpperHessenberg`, which now retains structure on some broadcasting operations. ```julia julia> UH = UpperHessenberg(ones(4,4)) 4×4 UpperHessenberg{Float64, Matrix{Float64}}: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ⋅ 1.0 1.0 1.0 ⋅ ⋅ 1.0 1.0 julia> UH .+ UH 4×4 UpperHessenberg{Float64, Matrix{Float64}}: 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 ⋅ 2.0 2.0 2.0 ⋅ ⋅ 2.0 2.0 julia> B = Bidiagonal(1:4, 1:3, :L); julia> UH .+ B 4×4 UpperHessenberg{Float64, Matrix{Float64}}: 2.0 1.0 1.0 1.0 2.0 3.0 1.0 1.0 ⋅ 3.0 4.0 1.0 ⋅ ⋅ 4.0 5.0 ``` Unlike an `UpperTriangular`, an `UpperHessenberg` can retain its structure in broadcasting operations involving other banded matrices such an `Tridiagonal`. We may also, in the future, make `::UpperTriangular .+ ::Bidiagonal` produce an `UpperHessenberg`.
1 parent d568106 commit 98723df

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

src/hessenberg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ size(H::UpperHessenberg) = size(H.data)
5959
axes(H::UpperHessenberg) = axes(H.data)
6060
parent(H::UpperHessenberg) = H.data
6161

62+
upperhessenbergdata(H::UpperHessenberg) = H.data
63+
upperhessenbergdata(A) = A
64+
6265
# similar behaves like UpperTriangular
6366
similar(H::UpperHessenberg, ::Type{T}) where {T} = UpperHessenberg(similar(H.data, T))
6467
similar(H::UpperHessenberg, ::Type{T}, dims::Dims{N}) where {T,N} = similar(H.data, T, dims)

src/structuredbroadcast.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
99
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
1010

11-
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12-
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11+
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12+
LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13+
UpperHessenberg{T}}
14+
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15+
LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16+
UpperHessenberg)
1317
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
1418
end
1519

@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
2731
StructuredMatrixStyle{LowerTriangular}()
2832
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
2933
StructuredMatrixStyle{UpperTriangular}()
34+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
35+
StructuredMatrixStyle{UpperHessenberg}()
3036

3137
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{Diagonal}) =
3238
StructuredMatrixStyle{Bidiagonal}()
3339
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3440
StructuredMatrixStyle{Tridiagonal}()
41+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Bidiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
42+
StructuredMatrixStyle{UpperHessenberg}()
43+
3544
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3645
StructuredMatrixStyle{Tridiagonal}()
46+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{SymTridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
47+
StructuredMatrixStyle{UpperHessenberg}()
3748
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
3849
StructuredMatrixStyle{Tridiagonal}()
50+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{Tridiagonal}, ::StructuredMatrixStyle{UpperHessenberg}) =
51+
StructuredMatrixStyle{UpperHessenberg}()
3952

4053
Broadcast.BroadcastStyle(::StructuredMatrixStyle{LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4154
StructuredMatrixStyle{LowerTriangular}()
4255
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4356
StructuredMatrixStyle{UpperTriangular}()
57+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
58+
StructuredMatrixStyle{UpperHessenberg}()
4459
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
4560
StructuredMatrixStyle{LowerTriangular}()
4661
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
4762
StructuredMatrixStyle{UpperTriangular}()
63+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{UnitUpperTriangular}, ::StructuredMatrixStyle{UpperHessenberg}) =
64+
StructuredMatrixStyle{UpperHessenberg}()
65+
66+
function Broadcast.BroadcastStyle(::StructuredMatrixStyle{UpperHessenberg},
67+
::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}})
68+
StructuredMatrixStyle{UpperHessenberg}()
69+
end
4870

49-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
71+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}) =
5072
StructuredMatrixStyle{Matrix}()
51-
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
73+
Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg}}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
5274
StructuredMatrixStyle{Matrix}()
5375

5476
# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -97,7 +119,7 @@ function structured_broadcast_alloc(bc, ::Type{Tridiagonal},
97119
Tridiagonal(Array{ElType}(undef, n1),Array{ElType}(undef, n),Array{ElType}(undef, n1))
98120
end
99121
function structured_broadcast_alloc(bc, ::Type{T}, ::Type{ElType},
100-
sz::NTuple{2,Integer}) where {ElType,T<:UpperOrLowerTriangular}
122+
sz::NTuple{2,Integer}) where {ElType,T<:Union{UpperOrLowerTriangular, UpperHessenberg}}
101123
T(Array{ElType}(undef, sz))
102124
end
103125
structured_broadcast_alloc(bc, ::Type{Matrix}, ::Type{ElType}, sz::NTuple{2,Integer}) where {ElType} =
@@ -278,6 +300,7 @@ function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
278300
end
279301
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
280302
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
303+
_preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A)
281304

282305
function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
283306
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
@@ -305,6 +328,19 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
305328
return dest
306329
end
307330

331+
function copyto!(dest::UpperHessenberg, bc::Broadcasted{<:StructuredMatrixStyle})
332+
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
333+
axs = axes(dest)
334+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
335+
bc_unwrapped = preprocess_broadcasted(UpperHessenberg, bc)
336+
for j in axs[2]
337+
for i in 1:min(size(dest.data,1), j+1)
338+
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
339+
end
340+
end
341+
return dest
342+
end
343+
308344
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
309345
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
310346
sz = size(A)

test/structuredbroadcast.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ using Main.LinearAlgebraTestHelpers.SizedArrays
2424
S = SymTridiagonal(rand(N), rand(max(0,N-1)))
2525
U = UpperTriangular(rand(N,N))
2626
L = LowerTriangular(rand(N,N))
27+
UH = UpperHessenberg(rand(N,N))
2728
M = Matrix(rand(N,N))
28-
structuredarrays = (D, B, T, U, L, M, S)
29+
structuredarrays = (D, B, T, U, L, M, S, UH)
2930
fstructuredarrays = map(Array, structuredarrays)
3031
@testset "$(nameof(typeof(X)))" for (X, fX) in zip(structuredarrays, fstructuredarrays)
3132
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@@ -137,6 +138,7 @@ end
137138
T = Tridiagonal(rand(max(0,N-1)), rand(N), rand(max(0,N-1)))
138139
= LowerTriangular(rand(N,N))
139140
= UpperTriangular(rand(N,N))
141+
UH = UpperHessenberg(rand(N,N))
140142
M = Matrix(rand(N,N))
141143

142144
@test broadcast!(sin, copy(D), D)::Diagonal == sin.(D)::Diagonal
@@ -145,13 +147,15 @@ end
145147
@test broadcast!(sin, copy(T), T)::Tridiagonal == sin.(T)::Tridiagonal
146148
@test broadcast!(sin, copy(◣), ◣)::LowerTriangular == sin.(◣)::LowerTriangular
147149
@test broadcast!(sin, copy(◥), ◥)::UpperTriangular == sin.(◥)::UpperTriangular
150+
@test broadcast!(sin, copy(UH), UH)::UpperHessenberg == sin.(UH)::UpperHessenberg
148151
@test broadcast!(sin, copy(M), M)::Matrix == sin.(M)::Matrix
149152
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
150153
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
151154
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
152155
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
153156
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
154157
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
158+
@test broadcast!(*, copy(UH), UH, A) == UpperHessenberg(broadcast(*, UH, A))
155159
@test broadcast!(*, copy(M), M, A) == Matrix(broadcast(*, M, A))
156160

157161
if N > 2
@@ -183,8 +187,9 @@ end
183187
S = SymTridiagonal(rand(N), rand(N - 1))
184188
U = UpperTriangular(rand(N,N))
185189
L = LowerTriangular(rand(N,N))
190+
UH = UpperHessenberg(rand(N,N))
186191
M = Matrix(rand(N,N))
187-
structuredarrays = (M, D, B, T, S, U, L)
192+
structuredarrays = (M, D, B, T, S, U, L, UH)
188193
fstructuredarrays = map(Array, structuredarrays)
189194
for (X, fX) in zip(structuredarrays, fstructuredarrays)
190195
@test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX))
@@ -398,4 +403,11 @@ end
398403
end
399404
end
400405

406+
@testset "Rectangular UpperHessenberg" begin
407+
UH = UpperHessenberg(ones(4,3))
408+
UH2 = UH .+ UH .- UH
409+
@test UH2 == UH
410+
@test UH2 isa UpperHessenberg
411+
end
412+
401413
end

0 commit comments

Comments
 (0)