Skip to content

Commit c23f255

Browse files
albertomercuriowsmoses
authored andcommitted
Relav method definition and add check for matrix only
1 parent 6f3bca0 commit c23f255

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/Overlay.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,18 @@ end
114114
## `mul!` goes through too many layers of abstractions and we aren't able to overload
115115
## without specializing on every possible combination of types
116116
for (cT, aT, bT) in (
117-
(:AbstractVector, :DenseMatrix, :AbstractVector),
118-
(:AbstractMatrix, :DenseMatrix, :AbstractVecOrMat),
117+
(:AbstractVector, :AbstractMatrix, :AbstractVector),
118+
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
119119
)
120120
@eval begin
121121
@reactant_overlay @noinline function LinearAlgebra.mul!(
122122
C::$cT, A::$aT, B::$bT, α::Number, β::Number
123123
)
124-
A, B = aos_to_soa(A), aos_to_soa(B)
124+
A2, B2 = aos_to_soa(A), aos_to_soa(B)
125125
C2 = aos_to_soa(C)
126-
if use_overlayed_version((C2, A, B))
127-
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
126+
# A2 can also be a SparseMatrix, which should be handled by its own methods
127+
if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2))
128+
TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β)
128129
if C2 !== C
129130
C .= C2
130131
end

0 commit comments

Comments
 (0)