Skip to content

Commit 8a6329b

Browse files
authored
More diag mul methods (#631)
1 parent a902e9b commit 8a6329b

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/host/linalg.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
266266
end
267267

268268
function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
269-
A::AbstractGPUArray,
270-
B::AbstractGPUArray)
269+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}},
270+
B::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T}
271271
dc = C.diag
272272
d = length(dc)
273273
m, n = size(A, 1), size(A, 2)
@@ -282,23 +282,22 @@ end
282282

283283
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
284284
D::Diagonal{<:Any, <:AbstractGPUArray},
285-
A::AbstractGPUVecOrMat)
285+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T}
286286
dd = D.diag
287287
d = length(dd)
288288
m, n = size(A, 1), size(A, 2)
289289
m′, n′ = size(B, 1), size(B, 2)
290290
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
291291
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
292292
@. B = dd * A
293-
294293
B
295294
end
296295

297296
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
298297
D::Diagonal{<:Any, <:AbstractGPUArray},
299-
A::AbstractGPUVecOrMat,
298+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}},
300299
α::Number,
301-
β::Number)
300+
β::Number) where {T}
302301
dd = D.diag
303302
d = length(dd)
304303
m, n = size(A, 1), size(A, 2)

test/testsuite/linalg.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@
238238
mul!(X, D, B)
239239
mul!(Y, Diagonal(collect(d)), collect(B))
240240
@test collect(X) Y
241+
mul!(X, D, adjoint(B))
242+
mul!(Y, Diagonal(collect(d)), collect(adjoint(B)))
243+
@test collect(X) Y
241244
mul!(X, D, B, α, β)
242245
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
243246
@test collect(X) Y
@@ -259,6 +262,11 @@
259262
C = Diagonal(d)
260263
mul!(C, a, b)
261264
@test collect(C) Diagonal(collect(a) * collect(b))
265+
a = transpose(AT(diagm(rand(elty, n))))
266+
b = adjoint(AT(diagm(rand(elty, n))))
267+
C = Diagonal(d)
268+
mul!(C, a, b)
269+
@test collect(C) Diagonal(collect(a) * collect(b))
262270
end
263271
end
264272

0 commit comments

Comments
 (0)