Skip to content

Commit fbbdb28

Browse files
authored
Clean up juliaBLAS.jl and test it (#95)
1 parent 887fbd3 commit fbbdb28

File tree

3 files changed

+58
-100
lines changed

3 files changed

+58
-100
lines changed

src/juliaBLAS.jl

Lines changed: 28 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@ function rankUpdate!(A::StridedMatrix, x::StridedVector, y::StridedVector, α::N
2424
end
2525
end
2626

27-
# Deprecated 11 October 2018
28-
Base.@deprecate rankUpdate!::Number, x::StridedVector, y::StridedVector, A::StridedMatrix) rankUpdate!(A, x, y, α)
29-
3027
## Hermitian
31-
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syr!(A.uplo, α, a, A.data)
32-
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasReal,S<:StridedMatrix} = rankUpdate!(one(T), a, A)
28+
function rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix}
29+
BLAS.syr!(A.uplo, α, a, A.data)
30+
return A
31+
end
32+
function rankUpdate!(A::Hermitian{Complex{T},S}, a::StridedVector{Complex{T}}, α::T) where {T<:BlasReal,S<:StridedMatrix}
33+
BLAS.her!(A.uplo, α, a, A.data)
34+
return A
35+
end
36+
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasFloat,S<:StridedMatrix} = rankUpdate!(A, a, one(real(T)))
3337

3438
### Generic
3539
function rankUpdate!(A::Hermitian, a::StridedVector, α::Real)
@@ -44,15 +48,18 @@ function rankUpdate!(A::Hermitian, a::StridedVector, α::Real)
4448
return A
4549
end
4650

47-
# Deprecated 11 October 2018
48-
Base.@deprecate rankUpdate!::Real, a::StridedVector, A::Hermitian) rankUpdate!(A, a, α)
49-
5051
# Rank k update
5152
## Real
52-
rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
53+
function rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix}
54+
BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
55+
return C
56+
end
5357

5458
## Complex
55-
rankUpdate!(C::Hermitian{T,S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.herk!(C.uplo, 'N', α, A, β, C.data)
59+
function rankUpdate!(C::Hermitian{Complex{T},S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix}
60+
BLAS.herk!(C.uplo, 'N', α, A, β, C.data)
61+
return C
62+
end
5663

5764
### Generic
5865
function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real)
@@ -80,93 +87,14 @@ function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real)
8087
return C
8188
end
8289

83-
# Deprecated 11 October 2018
84-
Base.@deprecate rankUpdate!::Real, A::StridedVecOrMat, C::Hermitian) rankUpdate!(C, A, α)
85-
Base.@deprecate rankUpdate!::Real, A::StridedVecOrMat, β::Real, C::Hermitian) rankUpdate!(C, A, α, β)
86-
87-
if VERSION < v"1.3.0-alpha.115" # Project.toml has julia = "1.6" so this block should no longer be necessary
88-
# BLAS style mul!
89-
## gemv
90-
mul!(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('N', α, A, x, β, y)
91-
mul!(y::StridedVector{T}, A::Adjoint{T,<:StridedMatrix{T}}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('C', α, parent(adjA), x, β, y)
92-
93-
## gemm
94-
mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('N', 'N', α, A, B, β, C)
95-
mul!(C::StridedMatrix{T}, adjA::Adjoint{T,<:StridedMatrix{T}}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('C', 'N', α, parent(adjA), B, β, C)
96-
# Not optimized since it is a generic fallback. Can probably soon be removed when the signatures in base have been updated.
97-
function mul!(C::StridedVecOrMat,
98-
A::StridedMatrix,
99-
B::StridedVecOrMat,
100-
α::Number,
101-
β::Number)
102-
103-
m, n = size(C, 1), size(C, 2)
104-
k = size(A, 2)
105-
106-
if β != 1
107-
if β == 0
108-
fill!(C, 0)
109-
else
110-
rmul!(C, β)
111-
end
112-
end
113-
for j = 1:n
114-
for i = 1:m
115-
for l = 1:k
116-
C[i,j] += α*A[i,l]*B[l,j]
117-
end
118-
end
119-
end
120-
return C
121-
end
122-
function mul!(C::StridedVecOrMat,
123-
adjA::Adjoint{<:Number,<:StridedMatrix},
124-
B::StridedVecOrMat,
125-
α::Number,
126-
β::Number)
127-
128-
A = parent(adjA)
129-
m, n = size(C, 1), size(C, 2)
130-
k = size(A, 1)
131-
132-
if β != 1
133-
if β == 0
134-
fill!(C, 0)
135-
else
136-
rmul!(C, β)
137-
end
138-
end
139-
for j = 1:n
140-
for i = 1:m
141-
for l = 1:k
142-
C[i,j] += α*A[l,i]'*B[l,j]
143-
end
144-
end
145-
end
146-
return C
147-
end
148-
149-
## trmm like
150-
### BLAS versions
151-
mul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'N', α, A.data, B)
152-
mul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'N', α, A.data, B)
153-
mul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'U', α, A.data, B)
154-
mul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'U', α, A.data, B)
155-
mul!(A::Adjoint{T,UpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'N', α, parent(A).data, B)
156-
mul!(A::Adjoint{T,LowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'N', α, parent(A).data, B)
157-
mul!(A::Adjoint{T,UnitUpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'U', α, parent(A).data, B)
158-
mul!(A::Adjoint{T,UnitLowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'U', α, parent(A).data, B)
159-
160-
end # VERSION
161-
16290
### Generic fallbacks
16391
function lmul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
16492
AA = A.data
16593
m, n = size(B)
166-
for i = 1:m
167-
for j = 1:n
94+
for i 1:m
95+
for j 1:n
16896
B[i,j] = α*AA[i,i]*B[i,j]
169-
for l = i + 1:m
97+
for l (i + 1):m
17098
B[i,j] += α*AA[i,l]*B[l,j]
17199
end
172100
end
@@ -176,10 +104,10 @@ end
176104
function lmul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
177105
AA = A.data
178106
m, n = size(B)
179-
for i = m:-1:1
180-
for j = 1:n
107+
for i m:-1:1
108+
for j 1:n
181109
B[i,j] = α*AA[i,i]*B[i,j]
182-
for l = 1:i - 1
110+
for l 1:(i - 1)
183111
B[i,j] += α*AA[i,l]*B[l,j]
184112
end
185113
end
@@ -189,11 +117,11 @@ end
189117
function lmul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
190118
AA = A.data
191119
m, n = size(B)
192-
for i = 1:m
193-
for j = 1:n
120+
for i 1:m
121+
for j 1:n
194122
B[i,j] = α*B[i,j]
195-
for l = i + 1:m
196-
B[i,j] = α*AA[i,l]*B[l,j]
123+
for l (i + 1):m
124+
B[i,j] += α*AA[i,l]*B[l,j]
197125
end
198126
end
199127
end
@@ -205,7 +133,7 @@ function lmul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T
205133
for i = m:-1:1
206134
for j = 1:n
207135
B[i,j] = α*B[i,j]
208-
for l = 1:i - 1
136+
for l = 1:(i - 1)
209137
B[i,j] += α*AA[i,l]*B[l,j]
210138
end
211139
end

test/juliaBLAS.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using Test, GenericLinearAlgebra, LinearAlgebra
2+
3+
@testset "rankUpdate!" begin
4+
A, B, x = (Hermitian(randn(5, 5)), randn(5, 2), randn(5))
5+
Ac, Bc, xc = (
6+
Hermitian(complex.(randn(5, 5), randn(5, 5))),
7+
complex.(randn(5, 2), randn(5, 2)),
8+
complex.(randn(5), randn(5)),
9+
)
10+
@test rankUpdate!(copy(A), x) A .+ x.*x'
11+
@test rankUpdate!(copy(Ac), xc) Ac .+ xc.*xc'
12+
13+
@test rankUpdate!(copy(A), B, 0.5, 0.5) 0.5*A + 0.5*B*B'
14+
@test rankUpdate!(copy(Ac), Bc, 0.5, 0.5) 0.5*Ac + 0.5*Bc*Bc'
15+
16+
@test invoke(rankUpdate!, Tuple{Hermitian,StridedVecOrMat,Real}, copy(Ac), Bc, 1.0)
17+
rankUpdate!(copy(Ac), Bc, 1.0, 1.0)
18+
end
19+
20+
@testset "triangular multiplication: $(typeof(T))" for T (
21+
UpperTriangular(complex.(randn(5, 5), randn(5, 5))),
22+
UnitUpperTriangular(complex.(randn(5, 5), randn(5, 5))),
23+
LowerTriangular(complex.(randn(5, 5), randn(5, 5))),
24+
UnitLowerTriangular(complex.(randn(5, 5), randn(5, 5))),
25+
)
26+
B = complex.(randn(5, 5), randn(5, 5))
27+
@test lmul!(T, copy(B), complex(0.5, 0.5)) T*B*complex(0.5, 0.5)
28+
@test lmul!(T', copy(B), complex(0.5, 0.5)) T'*B*complex(0.5, 0.5)
29+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22

33
# @testset "The LinearAlgebra Test Suite" begin
4+
include("juliaBLAS.jl")
45
include("cholesky.jl")
56
include("qr.jl")
67
include("eigenselfadjoint.jl")

0 commit comments

Comments
 (0)