@@ -64,7 +64,7 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
6464
6565@inline function LinearAlgebra. mul! (dest:: StaticVecOrMatLike{TDest} , A:: StaticVecOrMatLike{TA} ,
6666 B:: StaticVecOrMatLike{TB} ) where {TDest,TA,TB}
67- TMul = typeof ( one (TA) * one (TB) + one (TA) * one (TB) )
67+ TMul = promote_op (matprod, TA, TB )
6868 return _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B, NoMulAdd {TMul, TDest} ())
6969end
7070
111111
112112" Obtain an expression for the linear index of var[k,j], taking transposes into account"
113113function _lind (var:: Symbol , A:: Type{TSize{sa,tA}} , k:: Int , j:: Int ) where {sa,tA}
114- return uplo_access (sa, var, k, j, tA)
114+ ula = uplo_access (sa, var, k, j, tA)
115+ if ula. head == :call && ula. args[1 ] == :transpose
116+ # TODO : can this be properly fixed at all?
117+ return ula. args[2 ]
118+ end
119+ return ula
115120end
116121
117122
126131
127132 if sa[2 ] != 0
128133 assign_expr = gen_by_access (wrapped_a) do access_a
129- lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
130- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
131- [:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
134+ lhs = [_lind (:c ,Sc,k,col) for k = 1 : sa[1 ]]
135+ ab = [combine_products ([:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
132136 exprs = _muladd_expr (lhs, ab, _add)
133137
134138 return :(@inbounds $ (Expr (:block , exprs... )))
@@ -221,13 +225,12 @@ end
221225 end
222226
223227 if sa[2 ] != 0
224- lhs = [:( $ ( _lind (:c , Sc, k1, k2)) ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
228+ lhs = [_lind (:c , Sc, k1, k2) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
225229
226230 assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
227231
228- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
229- [:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
230- ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
232+ ab = [combine_products ([:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
233+ ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
231234
232235 exprs = _muladd_expr (lhs, ab, _add)
233236 return :(@inbounds $ (Expr (:block , exprs... )))
246249 c = mul_parent (wrapped_c)
247250 a = mul_parent (wrapped_a)
248251 b = mul_parent (wrapped_b)
252+ T = promote_op (matprod,Ta,Tb)
249253 $ assign_expr
250254 return c
251255 end
259263 end
260264
261265 # This will not work for Symmetric and Hermitian wrappers of c
262- lhs = [:( $ ( _lind (:c , Sc, k1, k2)) ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
266+ lhs = [_lind (:c , Sc, k1, k2) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
263267
264268 # vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
265269
299303 end
300304
301305 if sa[2 ] != 0
302- exprs = [reduce ((ex1,ex2) -> :( + ( $ ex1, $ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
306+ exprs = [combine_products ( [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
303307 else
304308 exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
305309 end
0 commit comments