Skip to content

Commit c96c494

Browse files
committed
Out-of-place triu/tril for Symmetric in each branch
1 parent 8d6ca14 commit c96c494

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

src/symmetric.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -520,49 +520,49 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)
520520
# tril/triu
521521
function tril(A::Hermitian, k::Integer=0)
522522
if A.uplo == 'U' && k <= 0
523-
return tril!(copy(A.data'),k)
523+
return tril_maybe_inplace(copy(A.data'),k)
524524
elseif A.uplo == 'U' && k > 0
525-
return tril!(copy(A.data'),-1) + tril!(triu(A.data),k)
525+
return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k)
526526
elseif A.uplo == 'L' && k <= 0
527527
return tril(A.data,k)
528528
else
529-
return tril(A.data,-1) + tril!(triu!(copy(A.data')),k)
529+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k)
530530
end
531531
end
532532

533533
function tril(A::Symmetric, k::Integer=0)
534534
if A.uplo == 'U' && k <= 0
535-
return tril!(copy(transpose(A.data)),k)
535+
return tril_maybe_inplace(copy(transpose(A.data)),k)
536536
elseif A.uplo == 'U' && k > 0
537-
return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k)
537+
return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k)
538538
elseif A.uplo == 'L' && k <= 0
539539
return tril(A.data,k)
540540
else
541-
return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k)
541+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k)
542542
end
543543
end
544544

545545
function triu(A::Hermitian, k::Integer=0)
546546
if A.uplo == 'U' && k >= 0
547547
return triu(A.data,k)
548548
elseif A.uplo == 'U' && k < 0
549-
return triu(A.data,1) + triu!(tril!(copy(A.data')),k)
549+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k)
550550
elseif A.uplo == 'L' && k >= 0
551-
return triu!(copy(A.data'),k)
551+
return triu_maybe_inplace(copy(A.data'),k)
552552
else
553-
return triu!(copy(A.data'),1) + triu!(tril(A.data),k)
553+
return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k)
554554
end
555555
end
556556

557557
function triu(A::Symmetric, k::Integer=0)
558558
if A.uplo == 'U' && k >= 0
559559
return triu(A.data,k)
560560
elseif A.uplo == 'U' && k < 0
561-
return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k)
561+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k)
562562
elseif A.uplo == 'L' && k >= 0
563-
return triu!(copy(transpose(A.data)),k)
563+
return triu_maybe_inplace(copy(transpose(A.data)),k)
564564
else
565-
return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k)
565+
return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k)
566566
end
567567
end
568568

src/triangular.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0)
530530
return tril!(LowerTriangular(A.data), k)
531531
end
532532

533+
tril_maybe_inplace(A, k::Integer=0) = tril(A, k)
534+
triu_maybe_inplace(A, k::Integer=0) = triu(A, k)
535+
tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k)
536+
triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k)
537+
533538
adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data))
534539
adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data))
535540
adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data))

test/symmetric.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,4 +1350,27 @@ end
13501350
@test LinearAlgebra.uplo(H) == :L
13511351
end
13521352

1353+
@testset "triu/tril with immutable arrays" begin
1354+
struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T}
1355+
a :: A
1356+
end
1357+
Base.size(A::ImmutableMatrix) = size(A.a)
1358+
Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j)
1359+
Base.copy(A::ImmutableMatrix) = A
1360+
LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a))
1361+
LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a))
1362+
1363+
A = ImmutableMatrix([1 2; 3 4])
1364+
for T in (Symmetric, Hermitian), uplo in (:U, :L)
1365+
H = T(A, uplo)
1366+
MH = Matrix(H)
1367+
@test triu(H,-1) == triu(MH,-1)
1368+
@test triu(H) == triu(MH)
1369+
@test triu(H,1) == triu(MH,1)
1370+
@test tril(H,1) == tril(MH,1)
1371+
@test tril(H) == tril(MH)
1372+
@test tril(H,-1) == tril(MH,-1)
1373+
end
1374+
end
1375+
13531376
end # module TestSymmetric

0 commit comments

Comments
 (0)