@@ -112,17 +112,18 @@ end
112112# # `mul!` goes through too many layers of abstractions and we aren't able to overload
113113# # without specializing on every possible combination of types
114114for (cT, aT, bT) in (
115- (:AbstractVector , :DenseMatrix , :AbstractVector ),
116- (:AbstractMatrix , :DenseMatrix , :AbstractVecOrMat ),
115+ (:AbstractVector , :AbstractMatrix , :AbstractVector ),
116+ (:AbstractMatrix , :AbstractMatrix , :AbstractVecOrMat ),
117117)
118118 @eval begin
119119 @reactant_overlay @noinline function LinearAlgebra. mul! (
120120 C:: $cT , A:: $aT , B:: $bT , α:: Number , β:: Number
121121 )
122- A, B = aos_to_soa (A), aos_to_soa (B)
122+ A2, B2 = aos_to_soa (A), aos_to_soa (B)
123123 C2 = aos_to_soa (C)
124- if use_overlayed_version ((C2, A, B))
125- TracedLinearAlgebra. overloaded_mul! (C2, A, B, α, β)
124+ # A2 can also be a SparseMatrix, which should be handled by its own methods
125+ if use_overlayed_version (A2) && use_overlayed_version ((C2, A2, B2))
126+ TracedLinearAlgebra. overloaded_mul! (C2, A2, B2, α, β)
126127 if C2 != = C
127128 C .= C2
128129 end
0 commit comments