@@ -13,14 +13,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1313@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Adjoint{<:Any,<:StaticVector} ) where {N} = vec (A) * B
1414@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
1515
16- @inline mul! (dest:: StaticVecOrMat , A:: StaticMatrix , B:: StaticVector ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
17- @inline mul! (dest:: StaticVecOrMat , A:: StaticMatrix , B:: StaticMatrix ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
18- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: StaticMatrix ) = mul! (dest, reshape (A, Size (Size (A)[1 ], 1 )), B)
19- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: Transpose{<:Any, <:StaticVector} ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
20- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: Adjoint{<:Any, <:StaticVector} ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
21- # @inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
22-
23-
2416
2517# Implementations
2618
9789
9890 # Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
9991 if sa[1 ]* sa[2 ]* sb[2 ] >= 14 * 14 * 14
92+ Sa = TSize {size(S),false} ()
93+ Sb = TSize {sa,false} ()
94+ Sc = TSize {sb,false} ()
95+ _add = MulAddMul (true ,false )
10096 return quote
10197 @_inline_meta
10298 C = similar (a, T, $ S)
103- mul_blas! ($ S , C, Sa, Sb, a, b)
99+ mul_blas! ($ Sa , C, $ Sa, $ Sb, a, b, $ _add )
104100 return C
105101 end
106102 elseif sa[1 ]* sa[2 ]* sb[2 ] < 8 * 8 * 8
177173 # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
178174 tmp_type_in = :(SVector{$ (sb[1 ]), T})
179175 tmp_type_out = :(SVector{$ (sa[1 ]), T})
180- vect_exprs = [:($ (Symbol (" tmp_$k2 " )):: $tmp_type_out = partly_unrolled_multiply (Size (a), Size ($ (sb[1 ])), a,
176+ vect_exprs = [:($ (Symbol (" tmp_$k2 " )):: $tmp_type_out = partly_unrolled_multiply (TSize (a), TSize ($ (sb[1 ])), a,
181177 $ (Expr (:call , tmp_type_in, [Expr (:ref , :b , LinearIndices (sb)[i, k2]) for i = 1 : sb[1 ]]. .. ))):: $tmp_type_out )
182178 for k2 = 1 : sb[2 ]]
183179
@@ -193,201 +189,4 @@ end
193189 end
194190end
195191
196- @generated function partly_unrolled_multiply (:: Size{sa} , :: Size{sb} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} ) where {sa, sb, Ta, Tb}
197- if sa[2 ] != sb[1 ]
198- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
199- end
200-
201- if sa[2 ] != 0
202- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
203- else
204- exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
205- end
206-
207- return quote
208- $ (Expr (:meta ,:noinline ))
209- @inbounds return SVector (tuple ($ (exprs... )))
210- end
211- end
212-
213- # TODO aliasing problems if c === b?
214- @generated function _mul! (:: Size{sc} , c:: StaticVector , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticVector ) where {sa, sb, sc}
215- if sb[1 ] != sa[2 ] || sc[1 ] != sa[1 ]
216- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
217- end
218-
219- if sa[2 ] != 0
220- exprs = [:(c[$ k] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
221- else
222- exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sa[1 ]]
223- end
224-
225- return quote
226- @_inline_meta
227- @inbounds $ (Expr (:block , exprs... ))
228- return c
229- end
230- end
231-
232- @generated function _mul! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticVector ,
233- b:: Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}} ) where {sa, sb, sc}
234- if sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
235- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
236- end
237-
238- exprs = [:(c[$ (LinearIndices (sc)[i, j])] = a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
239-
240- return quote
241- @_inline_meta
242- @inbounds $ (Expr (:block , exprs... ))
243- return c
244- end
245- end
246-
247- @generated function _mul! (Sc:: Size{sc} , c:: StaticMatrix{<:Any, <:Any, Tc} , Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: StaticMatrix{<:Any, <:Any, Tb} ) where {sa, sb, sc, Ta, Tb, Tc}
248- can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
249-
250- if can_blas
251- if sa[1 ] * sa[2 ] * sb[2 ] < 4 * 4 * 4
252- return quote
253- @_inline_meta
254- mul_unrolled! (Sc, c, Sa, Sb, a, b)
255- return c
256- end
257- elseif sa[1 ] * sa[2 ] * sb[2 ] < 14 * 14 * 14 # Something seems broken for this one with large matrices (becomes allocating)
258- return quote
259- @_inline_meta
260- mul_unrolled_chunks! (Sc, c, Sa, Sb, a, b)
261- return c
262- end
263- else
264- return quote
265- @_inline_meta
266- mul_blas! (Sc, c, Sa, Sb, a, b)
267- return c
268- end
269- end
270- else
271- if sa[1 ] * sa[2 ] * sb[2 ] < 4 * 4 * 4
272- return quote
273- @_inline_meta
274- mul_unrolled! (Sc, c, Sa, Sb, a, b)
275- return c
276- end
277- else
278- return quote
279- @_inline_meta
280- mul_unrolled_chunks! (Sc, c, Sa, Sb, a, b)
281- return c
282- end
283- end
284- end
285- end
286-
287-
288- @generated function mul_blas! (:: Size{s} , c:: StaticMatrix{<:Any, <:Any, T} , :: Size{sa} , :: Size{sb} , a:: StaticMatrix{<:Any, <:Any, T} , b:: StaticMatrix{<:Any, <:Any, T} ) where {s,sa,sb, T <: BlasFloat }
289- if sb[1 ] != sa[2 ] || sa[1 ] != s[1 ] || sb[2 ] != s[2 ]
290- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $s " ))
291- end
292-
293- if sa[1 ] > 0 && sa[2 ] > 0 && sb[2 ] > 0
294- # This code adapted from `gemm!()` in base/linalg/blas.jl
295-
296- if T == Float64
297- gemm = :dgemm_
298- elseif T == Float32
299- gemm = :sgemm_
300- elseif T == Complex{Float64}
301- gemm = :zgemm_
302- else # T == Complex{Float32}
303- gemm = :cgemm_
304- end
305-
306- blascall = quote
307- ccall ((LinearAlgebra. BLAS. @blasfunc ($ gemm), LinearAlgebra. BLAS. libblas), Nothing,
308- (Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra. BLAS. BlasInt}, Ref{LinearAlgebra. BLAS. BlasInt},
309- Ref{LinearAlgebra. BLAS. BlasInt}, Ref{$ T}, Ptr{$ T}, Ref{LinearAlgebra. BLAS. BlasInt},
310- Ptr{$ T}, Ref{LinearAlgebra. BLAS. BlasInt}, Ref{$ T}, Ptr{$ T},
311- Ref{LinearAlgebra. BLAS. BlasInt}),
312- transA, transB, m, n,
313- ka, alpha, a, strideA,
314- b, strideB, beta, c,
315- strideC)
316- end
317-
318- return quote
319- alpha = one (T)
320- beta = zero (T)
321- transA = ' N'
322- transB = ' N'
323- m = $ (sa[1 ])
324- ka = $ (sa[2 ])
325- kb = $ (sb[1 ])
326- n = $ (sb[2 ])
327- strideA = $ (sa[1 ])
328- strideB = $ (sb[1 ])
329- strideC = $ (s[1 ])
330-
331- $ blascall
332-
333- return c
334- end
335- else
336- throw (DimensionMismatch (" Cannot call BLAS gemm with zero-dimension arrays, attempted $sa * $sb -> $s ." ))
337- end
338- end
339-
340-
341- @generated function mul_unrolled! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticMatrix ) where {sa, sb, sc}
342- if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
343- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
344- end
345-
346- if sa[2 ] != 0
347- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k1, j])]* b[$ (LinearIndices (sb)[j, k2])]) for j = 1 : sa[2 ]]))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
348- else
349- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = zero (eltype (c))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
350- end
351-
352- return quote
353- @_inline_meta
354- @inbounds $ (Expr (:block , exprs... ))
355- end
356- end
357-
358- @generated function mul_unrolled_chunks! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticMatrix ) where {sa, sb, sc}
359- if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
360- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
361- end
362-
363- # vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
364-
365- # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
366- tmp_type = SVector{sb[1 ], eltype (c)}
367- vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (Size (sa)), $ (Size (sb[1 ])), a, $ (Expr (:call , tmp_type, [Expr (:ref , :b , LinearIndices (sb)[i, k2]) for i = 1 : sb[1 ]]. .. )))) for k2 = 1 : sb[2 ]]
368-
369- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = $ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
370-
371- return quote
372- @_inline_meta
373- @inbounds $ (Expr (:block , vect_exprs... ))
374- @inbounds $ (Expr (:block , exprs... ))
375- end
376- end
377-
378- # function mul_blas(a, b, c, A, B)
379- # q
380- # end
381-
382- # The idea here is to get pointers to stack variables and call BLAS.
383- # This saves an aweful lot of time compared to copying SArray's to Ref{SArray{...}}
384- # and using BLAS should be fastest for (very) large SArrays
385-
386- # Here is an LLVM function that gets the pointer to its input, %x
387- # After this we would make the ccall above.
388192#
389- # define i8* @f(i32 %x) #0 {
390- # %1 = alloca i32, align 4
391- # store i32 %x, i32* %1, align 4
392- # ret i32* %1
393- # }
0 commit comments