Skip to content

Commit e8915ce

Browse files
Relav method definition and add check for matrix only
1 parent 3cd1f60 commit e8915ce

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/Overlay.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,18 @@ end
112112
## `mul!` goes through too many layers of abstractions and we aren't able to overload
113113
## without specializing on every possible combination of types
114114
for (cT, aT, bT) in (
115-
(:AbstractVector, :DenseMatrix, :AbstractVector),
116-
(:AbstractMatrix, :DenseMatrix, :AbstractVecOrMat),
115+
(:AbstractVector, :AbstractMatrix, :AbstractVector),
116+
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
117117
)
118118
@eval begin
119119
@reactant_overlay @noinline function LinearAlgebra.mul!(
120120
C::$cT, A::$aT, B::$bT, α::Number, β::Number
121121
)
122-
A, B = aos_to_soa(A), aos_to_soa(B)
122+
A2, B2 = aos_to_soa(A), aos_to_soa(B)
123123
C2 = aos_to_soa(C)
124-
if use_overlayed_version((C2, A, B))
125-
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
124+
# A2 can also be a SparseMatrix, which should be handled by its own methods
125+
if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2))
126+
TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β)
126127
if C2 !== C
127128
C .= C2
128129
end

0 commit comments

Comments
 (0)