@@ -266,8 +266,8 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
266266end
267267
268268function LinearAlgebra. mul! (C:: Diagonal{<:Any, <:AbstractGPUArray} ,
269- A:: AbstractGPUArray ,
270- B:: AbstractGPUArray )
269+ A:: Union{ AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ,
270+ B:: Union{ AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ) where {T}
271271 dc = C. diag
272272 d = length (dc)
273273 m, n = size (A, 1 ), size (A, 2 )
@@ -282,23 +282,22 @@ end
282282
283283function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
284284 D:: Diagonal{<:Any, <:AbstractGPUArray} ,
285- A:: AbstractGPUVecOrMat )
285+ A:: Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ) where {T}
286286 dd = D. diag
287287 d = length (dd)
288288 m, n = size (A, 1 ), size (A, 2 )
289289 m′, n′ = size (B, 1 ), size (B, 2 )
290290 m == d || throw (DimensionMismatch (" right hand side has $m rows but D is $d by $d " ))
291291 (m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
292292 @. B = dd * A
293-
294293 B
295294end
296295
297296function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
298297 D:: Diagonal{<:Any, <:AbstractGPUArray} ,
299- A:: AbstractGPUVecOrMat ,
298+ A:: Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ,
300299 α:: Number ,
301- β:: Number )
300+ β:: Number ) where {T}
302301 dd = D. diag
303302 d = length (dd)
304303 m, n = size (A, 1 ), size (A, 2 )
0 commit comments