Skip to content

Commit a2358df

Browse files
Merge remote-tracking branch 'upstream/master'
2 parents fbf9475 + 7b21cab commit a2358df

File tree

13 files changed

+186
-88
lines changed

13 files changed

+186
-88
lines changed

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ LinearAlgebra.hermitianpart
604604
LinearAlgebra.hermitianpart!
605605
LinearAlgebra.copy_adjoint!
606606
LinearAlgebra.copy_transpose!
607+
LinearAlgebra.uplo
607608
```
608609

609610
## Low-level matrix operations

src/LinearAlgebra.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ public AbstractTriangular,
181181
zeroslike,
182182
matprod_dest,
183183
fillstored!,
184-
fillband!
184+
fillband!,
185+
uplo
185186

186187
const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
187188
const BlasReal = Union{Float64,Float32}
@@ -363,7 +364,14 @@ function char_uplo(uplo::Symbol)
363364
end
364365
end
365366

367+
"""
368+
sym_uplo(uplo::Char)
369+
370+
Return the `Symbol` corresponding the `uplo` by checking for validity.
371+
"""
366372
function sym_uplo(uplo::Char)
373+
# This method is called by other packages, and isn't used within LinearAlgebra
374+
# It's retained here for backward compatibility.
367375
if uplo == 'U'
368376
return :U
369377
elseif uplo == 'L'
@@ -372,6 +380,13 @@ function sym_uplo(uplo::Char)
372380
throw_uplo()
373381
end
374382
end
383+
"""
384+
_sym_uplo(uplo::Char)
385+
386+
Return the `Symbol` corresponding to `uplo` without checking for validity.
387+
See also `sym_uplo`, which checks for validity.
388+
"""
389+
_sym_uplo(uplo::Char) = uplo == 'U' ? (:U) : (:L)
375390

376391
@noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)"))
377392

src/bidiag.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ Bidiagonal(A::Bidiagonal) = A
115115
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
116116
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)
117117

118+
"""
119+
LinearAlgebra.uplo(S::Bidiagonal)::Symbol
120+
121+
Return a `Symbol` corresponding to whether the upper (`:U`) or lower (`:L`) off-diagonal band is stored.
122+
"""
123+
uplo(B::Bidiagonal) = sym_uplo(B.uplo)
124+
118125
_offdiagind(uplo) = uplo == 'U' ? 1 : -1
119126

120127
@inline function Base.isassigned(A::Bidiagonal, i::Int, j::Int)
@@ -295,7 +302,7 @@ function show(io::IO, M::Bidiagonal)
295302
print(io, ", ")
296303
show(io, M.ev)
297304
print(io, ", ")
298-
show(io, sym_uplo(M.uplo))
305+
show(io, _sym_uplo(M.uplo))
299306
print(io, ")")
300307
end
301308

@@ -910,12 +917,12 @@ function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
910917
@inbounds begin
911918
# first row of C
912919
for j in 1:min(2, n)
913-
C[1,j] += _add(A[1,j]*B[j,j])
920+
C[1,j] += _add(A[1,j]*Bd[j])
914921
end
915922
# second row of C
916923
if n > 1
917924
for j in 1:min(3, n)
918-
C[2,j] += _add(A[2,j]*B[j,j])
925+
C[2,j] += _add(A[2,j]*Bd[j])
919926
end
920927
end
921928
for j in 3:n-2
@@ -926,13 +933,13 @@ function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
926933
if n > 3
927934
# row before last of C
928935
for j in n-2:n
929-
C[n-1,j] += _add(A[n-1,j]*B[j,j])
936+
C[n-1,j] += _add(A[n-1,j]*Bd[j])
930937
end
931938
end
932939
# last row of C
933940
if n > 2
934941
for j in n-1:n
935-
C[n,j] += _add(A[n,j]*B[j,j])
942+
C[n,j] += _add(A[n,j]*Bd[j])
936943
end
937944
end
938945
end # inbounds

src/bunchkaufman.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ function bunchkaufman!(A::StridedMatrix{<:BlasFloat}, rook::Bool = false; check:
130130
end
131131

132132
bkcopy_oftype(A, S) = eigencopy_oftype(A, S)
133-
bkcopy_oftype(A::Symmetric{<:Complex}, S) = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
133+
function bkcopy_oftype(A::Symmetric{<:Complex}, S)
134+
Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), _sym_uplo(A.uplo))
135+
end
134136

135137
"""
136138
bunchkaufman(A, rook::Bool=false; check = true) -> S::BunchKaufman

src/generic.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -579,20 +579,14 @@ function generic_norm2(x)
579579
T = typeof(maxabs)
580580
if isfinite(length(x)*maxabs*maxabs) && !iszero(maxabs*maxabs) # Scaling not necessary
581581
sum::promote_type(Float64, T) = norm_sqr(v)
582-
while true
583-
y = iterate(x, s)
584-
y === nothing && break
585-
(v, s) = y
582+
for v in Iterators.rest(x, s)
586583
sum += norm_sqr(v)
587584
end
588585
ismissing(sum) && return missing
589586
return convert(T, sqrt(sum))
590587
else
591588
sum = abs2(norm(v)/maxabs)
592-
while true
593-
y = iterate(x, s)
594-
y === nothing && break
595-
(v, s) = y
589+
for v in Iterators.rest(x, s)
596590
sum += (norm(v)/maxabs)^2
597591
end
598592
ismissing(sum) && return missing
@@ -614,21 +608,15 @@ function generic_normp(x, p)
614608
spp::promote_type(Float64, T) = p
615609
if -1 <= p <= 1 || (isfinite(length(x)*maxabs^spp) && !iszero(maxabs^spp)) # scaling not necessary
616610
sum::promote_type(Float64, T) = norm(v)^spp
617-
while true
618-
y = iterate(x, s)
619-
y === nothing && break
620-
(v, s) = y
611+
for v in Iterators.rest(x, s)
621612
ismissing(v) && return missing
622613
sum += norm(v)^spp
623614
end
624615
return convert(T, sum^inv(spp))
625616
else # rescaling
626617
sum = (norm(v)/maxabs)^spp
627618
ismissing(sum) && return missing
628-
while true
629-
y = iterate(x, s)
630-
y === nothing && break
631-
(v, s) = y
619+
for v in Iterators.rest(x, s)
632620
ismissing(v) && return missing
633621
sum += (norm(v)/maxabs)^spp
634622
end
@@ -996,7 +984,8 @@ function dot(x, y) # arbitrary iterables
996984
return s
997985
end
998986

999-
dot(x::Number, y::Number) = conj(x) * y
987+
# the unary + is for type promotion in the Boolean case, mimicking the reduction in usual dot
988+
dot(x::Number, y::Number) = +(conj(x) * y)
1000989

1001990
function dot(x::AbstractArray, y::AbstractArray)
1002991
lx = length(x)

src/special.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,19 @@ end
292292

293293
for f in (:+, :-)
294294
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
295-
uplo = sym_uplo(S.uplo)
295+
uplo = _sym_uplo(S.uplo)
296296
return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo)
297297
end
298298
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
299-
uplo = sym_uplo(S.uplo)
299+
uplo = _sym_uplo(S.uplo)
300300
return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo)
301301
end
302302
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
303-
uplo = sym_uplo(H.uplo)
303+
uplo = _sym_uplo(H.uplo)
304304
return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo)
305305
end
306306
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
307-
uplo = sym_uplo(H.uplo)
307+
uplo = _sym_uplo(H.uplo)
308308
return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo)
309309
end
310310
end
@@ -608,8 +608,8 @@ end
608608
# tridiagonal cholesky factorization
609609
function cholesky(S::RealSymHermitian{<:BiTriSym}, ::NoPivot = NoPivot(); check::Bool = true)
610610
T = choltype(S)
611-
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo))
612-
cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check)
611+
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), _sym_uplo(S.uplo))
612+
cholesky!(Hermitian(B, _sym_uplo(S.uplo)), NoPivot(); check = check)
613613
end
614614

615615
# istriu/istril for triangular wrappers of structured matrices

src/structuredbroadcast.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,30 @@ end
230230
# All structured matrices are square, and therefore they only broadcast out if they are size (1, 1)
231231
Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I
232232

233+
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
234+
# We may do this because the indexing within `copyto!` is restricted to the stored indices
235+
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
236+
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
237+
args = map(x -> preprocess_broadcasted(T, x), bc.args)
238+
Broadcast.broadcasted(bc.f, args...)
239+
end
240+
# fallback case that doesn't unwrap at all
241+
_preprocess_broadcasted(::Type, x) = x
242+
243+
_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag
244+
# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers
245+
_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diagview(d)
246+
247+
function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}})
248+
if isstructurepreserving(bc) || fzeropreserving(bc)
249+
# forward the broadcasting operation to the diagonal
250+
bc2 = preprocess_broadcasted(Diagonal, bc)
251+
return Diagonal(copy(bc2))
252+
else
253+
@invoke copy(bc::Broadcasted)
254+
end
255+
end
256+
233257
function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
234258
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
235259
axs = axes(dest)
@@ -291,13 +315,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
291315
return dest
292316
end
293317

294-
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
295-
# We may do this because the indexing within `copyto!` is restricted to the stored indices
296-
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
297-
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
298-
args = map(x -> preprocess_broadcasted(T, x), bc.args)
299-
Broadcast.Broadcasted(bc.f, args, bc.axes)
300-
end
301318
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
302319
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
303320
_preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A)

0 commit comments

Comments
 (0)