11# import LinearAlgebra.MulAddMul
22
# Diff hunk: the abstract supertype goes from one type parameter (T) to two (TA, TB).
3- abstract type MulAddMul{T } end
3+ abstract type MulAddMul{TA,TB } end
44
# Diff hunk for AlphaBeta, the MulAddMul subtype that stores the two scalars α and β.
# Removed lines: both fields share one type T; the inner constructor requires T <: Real.
# Added lines: α and β are typed independently (TA, TB) and no inner constructor is declared.
5- struct AlphaBeta{T} <: MulAddMul{T}
6- α:: T
7- β:: T
8- function AlphaBeta {T} (α,β) where T <: Real
9- new {T} (α,β)
10- end
5+ struct AlphaBeta{TA,TB} <: MulAddMul{TA,TB}
6+ α:: TA
7+ β:: TB
118end
# Removed outer constructor: it promoted the two argument types to a single T via promote_type.
12- @inline AlphaBeta (α:: A ,β:: B ) where {A,B} = AlphaBeta {promote_type(A,B)} (α,β)
# Field accessors for AlphaBeta: alpha returns the α field, beta returns the β field.
# Both are context lines, i.e. unchanged by this diff.
139@inline alpha (ab:: AlphaBeta ) = ab. α
1410@inline beta (ab:: AlphaBeta ) = ab. β
1511
# Diff hunk for NoMulAdd, a field-less MulAddMul subtype whose alpha is one(...) and beta is zero(...).
# Removed lines: a single parameter T supplies both one(T) and zero(T).
# Added lines: two parameters; alpha is one(TA) and beta is zero(TB).
16- struct NoMulAdd{T } <: MulAddMul{T } end
17- @inline alpha (ma:: NoMulAdd{T } ) where T = one (T )
18- @inline beta (ma:: NoMulAdd{T } ) where T = zero (T )
12+ struct NoMulAdd{TA,TB } <: MulAddMul{TA,TB } end
13+ @inline alpha (ma:: NoMulAdd{TA,TB } ) where {TA,TB} = one (TA )
14+ @inline beta (ma:: NoMulAdd{TA,TB } ) where {TA,TB} = zero (TB )
1915
2016"""
2117 StaticMatMulLike
@@ -63,12 +59,14 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
6359# 5-argument matrix multiplication
6460# To avoid allocations, strip away Transpose type and store transpose info in Size
# mul! with explicit α and β (both restricted to Real): wraps them in AlphaBeta and forwards to _mul!.
# Removed line: A and B were passed as TSize(A), TSize(B), mul_parent(A), mul_parent(B).
# Added line: A and B are passed as Size(A), Size(B) and the arrays themselves; dest is still
# passed as TSize(dest), mul_parent(dest).
# NOTE(review): after this change only dest is unwrapped at this call site — confirm the
# explanatory comment preceding this method still describes A and B accurately.
6561@inline LinearAlgebra. mul! (dest:: StaticVecOrMatLike , A:: StaticVecOrMatLike , B:: StaticVecOrMatLike ,
66- α:: Real , β:: Real ) = _mul! (TSize (dest), mul_parent (dest), TSize (A), TSize (B), mul_parent (A), mul_parent (B) ,
62+ α:: Real , β:: Real ) = _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B ,
6763 AlphaBeta (α,β))
6864
# mul! without α/β: forwards to _mul! with a NoMulAdd marker instead of AlphaBeta.
# Removed lines: A and B had to share one element type T, and the marker was NoMulAdd{T}().
# Added lines: dest, A and B carry independent element types TDest, TA, TB.
69- @inline LinearAlgebra. mul! (dest:: StaticVecOrMatLike , A:: StaticVecOrMatLike{T} ,
70- B:: StaticVecOrMatLike{T} ) where T =
71- _mul! (TSize (dest), mul_parent (dest), TSize (A), TSize (B), mul_parent (A), mul_parent (B), NoMulAdd {T} ())
65+ @inline function LinearAlgebra. mul! (dest:: StaticVecOrMatLike{TDest} , A:: StaticVecOrMatLike{TA} ,
66+ B:: StaticVecOrMatLike{TB} ) where {TDest,TA,TB}
# TMul is the type of a sum of two element products; it becomes the first parameter of NoMulAdd
# (the one alpha() reads), while TDest becomes the second (the one beta() reads).
67+ TMul = typeof (one (TA)* one (TB)+ one (TA)* one (TB))
68+ return _mul! (TSize (dest), mul_parent (dest), Size (A), Size (B), A, B, NoMulAdd {TMul, TDest} ())
69+ end
7270
7371
7472" Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling."
@@ -112,55 +110,58 @@ end
112110end
113111
# Diff hunk: the removed one-liner defaulted the variable name to :a; only the method taking an
# explicit `var` symbol remains.
114112" Obtain an expression for the linear index of var[k,j], taking transposes into account"
115- @inline _lind (A:: Type{<:TSize} , k:: Int , j:: Int ) = _lind (:a , A, k, j)
116113function _lind (var:: Symbol , A:: Type{TSize{sa,tA}} , k:: Int , j:: Int ) where {sa,tA}
# Delegates to uplo_access with the size tuple `sa` and the TSize tag `tA`.
117114 return uplo_access (sa, var, k, j, tA)
118115end
119116
120117
121118
122119# Matrix-vector multiplication
# Generated matrix-vector kernel. At generation time it validates the sizes, builds one unrolled
# assignment per row of `a`, and returns a quote that binds α and β from `_add`, runs the
# assignments under @inbounds and returns c.
# Removed lines: a and b sizes arrive as TSize, `a` is already a plain StaticMatrix, and elements
# of a are addressed with _lind(Sa, k, j).
# Added lines: sizes arrive as Size, the matrix arrives still wrapped (`wrapped_a`), the expression
# is built inside gen_by_access and elements of a are addressed with uplo_access(sa, :a, k, j, access_a).
123- @generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: TSize {sa} , Sb:: TSize {sb} ,
124- a :: StaticMatrix , b:: StaticVector , _add:: MulAddMul ,
125- :: Val{col} = Val (1 )) where {sa, sb, sc, col}
120+ @generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: Size {sa} , Sb:: Size {sb} ,
121+ wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} , _add:: MulAddMul ,
122+ :: Val{col} = Val (1 )) where {sa, sb, sc, col, Ta, Tb }
# Inner dimensions must agree and c must have as many rows as a.
126123 if sa[2 ] != sb[1 ] || sc[1 ] != sa[1 ]
127124 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
128125 end
129126
# Non-empty inner dimension: lhs addresses column `col` of c; each right-hand side is the
# unrolled sum over j of a[k,j] * b[j], combined with lhs by _muladd_expr.
130127 if sa[2 ] != 0
131- lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
132- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
133- [:($ (_lind (Sa,k,j))* b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
134- exprs = _muladd_expr (lhs, ab, _add)
128+ assign_expr = gen_by_access (wrapped_a) do access_a
129+ lhs = [:($ (_lind (:c ,Sc,k,col))) for k = 1 : sa[1 ]]
130+ ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
131+ [:($ (uplo_access (sa, :a , k, j, access_a)) * b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
132+ exprs = _muladd_expr (lhs, ab, _add)
133+
134+ return :(@inbounds $ (Expr (:block , exprs... )))
135+ end
# Empty inner dimension: every c[k] is set to zero(eltype(c)).
# NOTE(review): this branch indexes c[k] directly and bypasses _muladd_expr, so neither `col`
# nor β is consulted here — confirm that is intended.
135136 else
136137 exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sa[1 ]]
138+ assign_expr = :(@inbounds $ (Expr (:block , exprs... )))
137139 end
138140
# Runtime body: in the added version `a` is obtained by unwrapping wrapped_a with mul_parent
# before the generated assignments run.
139141 return quote
140142 # @_inline_meta
141- # α = _add.alpha
142- # β = _add.beta
143143 α = alpha (_add)
144144 β = beta (_add)
145- @inbounds $ (Expr (:block , exprs... ))
145+ a = mul_parent (wrapped_a)
146+ $ assign_expr
146147 return c
147148 end
148149end
149150
150151# Outer product
151- @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize {sa,:any } , tsb:: Union{TSize{sb,:transpose},TSize{sb,:adjoint} } ,
152- a:: StaticVector , b:: StaticVector , _add:: MulAddMul ) where {sa, sb, sc}
152+ @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , tsa :: Size {sa} , tsb:: Size{sb } ,
153+ a:: StaticVector , b:: Union{Transpose{<:Any, <: StaticVector}, Adjoint{<:Any, <:StaticVector}} , _add:: MulAddMul ) where {sa, sb, sc}
153154 if sc[1 ] != sa[1 ] || sc[2 ] != sb[2 ]
154155 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
155156 end
156157
157- conjugate_b = isa (tsb, TSize{sb, :adjoint })
158+ conjugate_b = b <: Adjoint
158159
159160 lhs = [:(c[$ (LinearIndices (sc)[i,j])]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
160161 if conjugate_b
161162 ab = [:(a[$ i] * adjoint (b[$ j])) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
162163 else
163- ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
164+ ab = [:(a[$ i] * transpose ( b[$ j]) ) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
164165 end
165166
166167 exprs = _muladd_expr (lhs, ab, _add)
175176end
176177
177178# Matrix-matrix multiplication
178- @generated function _mul! (Sc:: TSize{sc} , c:: StaticMatrixLike ,
179- Sa:: TSize {sa} , Sb:: TSize {sb} ,
180- a:: StaticMatrixLike , b:: StaticMatrixLike ,
179+ @generated function _mul! (Sc:: TSize{sc} , c:: StaticMatMulLike ,
180+ Sa:: Size {sa} , Sb:: Size {sb} ,
181+ a:: StaticMatMulLike , b:: StaticMatMulLike ,
181182 _add:: MulAddMul ) where {sa, sb, sc}
182183 Ta,Tb,Tc = eltype (a), eltype (b), eltype (c)
183184 can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
199200 if can_blas
200201 return quote
201202 @_inline_meta
202- mul_blas! (Sc, c, Sa, Sb, a, b , _add)
203+ mul_blas! (Sc, c, TSize (a), TSize (b), mul_parent (a), mul_parent (b) , _add)
203204 return c
204205 end
205206 else
@@ -213,18 +214,27 @@ end
213214end
214215
215216
216- @generated function muladd_unrolled_all! (Sc:: TSize{sc} , c :: StaticMatrixLike , Sa:: TSize {sa} , Sb:: TSize {sb} ,
217- a :: StaticMatrixLike , b :: StaticMatrixLike , _add:: MulAddMul ) where {sa, sb, sc}
217+ @generated function muladd_unrolled_all! (Sc:: TSize{sc} , wrapped_c :: StaticMatMulLike , Sa:: Size {sa} , Sb:: Size {sb} ,
218+ wrapped_a :: StaticMatMulLike{<:Any,<:Any,Ta} , wrapped_b :: StaticMatMulLike{<:Any,<:Any,Tb} , _add:: MulAddMul ) where {sa, sb, sc, Ta, Tb }
218219 if ! check_dims (Size (sc),Size (sa),Size (sb))
219220 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
220221 end
221222
222223 if sa[2 ] != 0
223224 lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
224- ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
225- [:($ (_lind (:a , Sa, k1, j)) * $ (_lind (:b , Sb, j, k2))) for j = 1 : sa[2 ]]
226- ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
227- exprs = _muladd_expr (lhs, ab, _add)
225+
226+ assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
227+
228+ ab = [:($ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
229+ [:($ (uplo_access (sa, :a , k1, j, access_a)) * $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
230+ ))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
231+
232+ exprs = _muladd_expr (lhs, ab, _add)
233+ return :(@inbounds $ (Expr (:block , exprs... )))
234+ end
235+ else
236+ exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sc[1 ]* sc[2 ]]
237+ assign_expr = :(@inbounds $ (Expr (:block , exprs... )))
228238 end
229239
230240 return quote
@@ -233,49 +243,63 @@ end
233243 # β = _add.beta
234244 α = alpha (_add)
235245 β = beta (_add)
236- @inbounds $ (Expr (:block , exprs... ))
246+ c = mul_parent (wrapped_c)
247+ a = mul_parent (wrapped_a)
248+ b = mul_parent (wrapped_b)
249+ $ assign_expr
250+ return c
237251 end
238252end
239253
240254
# Generated matrix-matrix kernel that works column by column: for every column k2 of b it emits
# tmp_k2 = partly_unrolled_multiply(...) and then assigns tmp_k2[k1] into c through _muladd_expr.
# Removed lines: a, b, c are plain StaticMatrix values, sizes are TSize (carrying tA/tB), and
# elements of b are addressed with _lind(:b, Sb, i, k2).
# Added lines: operands arrive still wrapped (wrapped_a, wrapped_b, wrapped_c), a and b sizes are
# Size, expressions are produced inside gen_by_access, elements of b are addressed with
# uplo_access(sb, :b, i, k2, access_b), and the access tag of a is forwarded as Val(access_a).
241- @generated function muladd_unrolled_chunks! (Sc:: TSize{sc} , c :: StaticMatrix , :: TSize {sa,tA } , Sb:: TSize {sb,tB } ,
242- a :: StaticMatrix , b :: StaticMatrix , _add:: MulAddMul ) where {sa, sb, sc, tA, tB }
255+ @generated function muladd_unrolled_chunks! (Sc:: TSize{sc} , wrapped_c :: StaticMatMulLike , :: Size {sa} , Sb:: Size {sb} ,
256+ wrapped_a :: StaticMatMulLike{<:Any,<:Any,Ta} , wrapped_b :: StaticMatMulLike{<:Any,<:Any,Tb} , _add:: MulAddMul ) where {sa, sb, sc, Ta, Tb }
# All three size relations (inner dimension, rows of c, columns of c) are checked at generation time.
243257 if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
244258 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
245259 end
246260
261+ # This will not work for Symmetric and Hermitian wrappers of c
262+ lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
263+
247264 # vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
248265
249266 # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
250- tmp_type = SVector{sb[1 ], eltype (c)}
251- vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (TSize {sa,tA} ()), $ (TSize {(sb[1],),tB} ()),
252- a, $ (Expr (:call , tmp_type, [:($ (_lind (:b , Sb, i, k2))) for i = 1 : sb[1 ]]. .. )))) for k2 = 1 : sb[2 ]]
# The temporary column vector uses the element type of the destination (eltype(wrapped_c)).
267+ tmp_type = SVector{sb[1 ], eltype (wrapped_c)}
253268
254- lhs = [:($ (_lind (:c , Sc, k1, k2))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
255- # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
256- rhs = [:($ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
257- exprs = _muladd_expr (lhs, rhs, _add)
# The do-block returns a quote with two @inbounds blocks: first the tmp_k2 column products,
# then the assignments that combine them with c.
269+ assign_expr = gen_by_access (wrapped_a, wrapped_b) do access_a, access_b
270+ vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (Size {sa} ()), $ (Size {(sb[1],)} ()),
271+ a, $ (Expr (:call , tmp_type, [uplo_access (sb, :b , i, k2, access_b) for i = 1 : sb[1 ]]. .. )), $ (Val (access_a)))) for k2 = 1 : sb[2 ]]
272+
273+ # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
274+ rhs = [:($ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
275+ exprs = _muladd_expr (lhs, rhs, _add)
258276
277+ return quote
278+ @inbounds $ (Expr (:block , vect_exprs... ))
279+ @inbounds $ (Expr (:block , exprs... ))
280+ end
281+ end
282+
# Runtime body: binds α and β from `_add`; the added version unwraps c, a and b with mul_parent
# before splicing in the generated assignments.
# NOTE(review): the added quote ends with the spliced assignments and has no explicit `return c`
# — confirm callers do not rely on the return value of this function.
259283 return quote
260284 @_inline_meta
261- # α = _add.alpha
262- # β = _add.beta
263285 α = alpha (_add)
264286 β = beta (_add)
265- @inbounds $ (Expr (:block , vect_exprs... ))
266- @inbounds $ (Expr (:block , exprs... ))
287+ c = mul_parent (wrapped_c)
288+ a = mul_parent (wrapped_a)
289+ b = mul_parent (wrapped_b)
290+ $ assign_expr
267291 end
268292end
269293
270294# @inline partly_unrolled_multiply(Sa::Size, Sb::Size, a::StaticMatrix, b::StaticArray) where {sa, sb, Ta, Tb} =
271295# partly_unrolled_multiply(TSize(Sa), TSize(Sb), a, b)
272- @generated function partly_unrolled_multiply (Sa:: TSize {sa} , :: TSize {sb} , a:: StaticMatrix {<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} ) where {sa, sb, Ta, Tb}
296+ @generated function partly_unrolled_multiply (Sa:: Size {sa} , :: Size {sb} , a:: StaticMatMulLike {<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} , :: Val{access_a} ) where {sa, sb, Ta, Tb, access_a }
273297 if sa[2 ] != sb[1 ]
274298 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
275299 end
276300
277301 if sa[2 ] != 0
278- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (_lind ( :a ,Sa,k,j ))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
302+ exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a ))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
279303 else
280304 exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
281305 end
0 commit comments