diff --git a/src/host/linalg.jl b/src/host/linalg.jl index bc910be4e..859382349 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -373,7 +373,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) end if isempty(A) || isempty(B) - return fill!(C, zero(R)) + return rmul!(C, add.beta) end @kernel function matmatmul_kernel!(C, A, B) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index b042d9408..30fd285fb 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -420,6 +420,16 @@ end @test compare(mul!, AT, rand(T, 2,2), rand(T, 2,1), f(rand(T, 2))) end end + + @testset "$T gemv zero-dim" for T in eltypes + y, A, x = rand(T, 4), rand(T, 4, 0), rand(T, 0) + + @test compare(*, AT, A, x) + @test compare(mul!, AT, y, A, x) + + y = rand(T, 4) + @test compare(mul!, AT, y, A, x, Ref(T(4)), Ref(T(5))) + end end @testsuite "linalg/mul!/matrix-matrix" (AT, eltypes)->begin @@ -434,6 +444,16 @@ end @test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5))) @test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix end + + @testset "$T gemm zero-dim" for T in eltypes + A, B, C = rand(T, 4, 0), rand(T, 0, 4), rand(T, 4, 4) + + @test compare(*, AT, A, B) + @test compare(mul!, AT, C, A, B) + + C = rand(T, 4, 4) + @test compare(mul!, AT, C, A, B, Ref(T(4)), Ref(T(5))) + end end @testsuite "linalg/norm" (AT, eltypes)->begin