@@ -114,17 +114,18 @@ end
114114# # `mul!` goes through too many layers of abstractions and we aren't able to overload
115115# # without specializing on every possible combination of types
116116for (cT, aT, bT) in (
117- (:AbstractVector , :DenseMatrix , :AbstractVector ),
118- (:AbstractMatrix , :DenseMatrix , :AbstractVecOrMat ),
117+ (:AbstractVector , :AbstractMatrix , :AbstractVector ),
118+ (:AbstractMatrix , :AbstractMatrix , :AbstractVecOrMat ),
119119)
120120 @eval begin
121121 @reactant_overlay @noinline function LinearAlgebra. mul! (
122122 C:: $cT , A:: $aT , B:: $bT , α:: Number , β:: Number
123123 )
124- A, B = aos_to_soa (A), aos_to_soa (B)
124+ A2, B2 = aos_to_soa (A), aos_to_soa (B)
125125 C2 = aos_to_soa (C)
126- if use_overlayed_version ((C2, A, B))
127- TracedLinearAlgebra. overloaded_mul! (C2, A, B, α, β)
126+ # A2 can also be a SparseMatrix, which should be handled by its own methods
127+ if use_overlayed_version (A2) && use_overlayed_version ((C2, A2, B2))
128+ TracedLinearAlgebra. overloaded_mul! (C2, A2, B2, α, β)
128129 if C2 != = C
129130 C .= C2
130131 end
0 commit comments