@@ -4,12 +4,20 @@ import LinearAlgebra: BlasFloat, matprod, mul!
44# Manage dispatch of * and mul!
55# TODO Adjoint? (Inner product?)
66
"""
    StaticMatMulLike{s1, s2, T}

Union of the matrix types accepted by the static multiplication dispatch in this file:
a plain `StaticMatrix{s1, s2, T}`, or one of the `LinearAlgebra` wrappers `Adjoint`,
`Transpose`, `Symmetric`, `Hermitian`, `LowerTriangular` and `UpperTriangular` around
such a matrix.

`s1` and `s2` are the size parameters of the wrapped static matrix and `T` is the
element type shared by the wrapper and the wrapped matrix.
"""
const StaticMatMulLike{s1, s2, T} = Union{
    StaticMatrix{s1, s2, T},
    Adjoint{T, <:StaticMatrix{s1, s2, T}},
    Transpose{T, <:StaticMatrix{s1, s2, T}},
    Symmetric{T, <:StaticMatrix{s1, s2, T}},
    Hermitian{T, <:StaticMatrix{s1, s2, T}},
    LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
    UpperTriangular{T, <:StaticMatrix{s1, s2, T}},
}
20+
1321
# Matrix-vector products: forward to `_mul`, passing the static size of `A` (and of `B`
# when `B` is itself a static vector) so the implementation can be selected by size.
@inline function *(A::StaticMatMulLike, B::AbstractVector)
    return _mul(Size(A), A, B)
end
@inline function *(A::StaticMatMulLike, B::StaticVector)
    return _mul(Size(A), Size(B), A, B)
end
@@ -20,6 +28,18 @@ const StaticMatMulLike{s1, s2, T} = Union{
# An N×1 static matrix times an adjoint/transposed static vector: reinterpret the
# single-column matrix as a vector with `vec` and reuse the vector * row-vector product.
@inline function *(
    A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}
) where {N}
    return vec(A) * B
end
@inline function *(
    A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}
) where {N}
    return vec(A) * B
end
2230
"""
    gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a)

Build the outer part of the code for a fully unrolled multiplication.

The returned expression performs whatever checks the wrapper type `a` requires (for
example whether a symmetric matrix view is `U` or `L`). The body of each branch is
produced by calling `expr_gen` with a symbol describing the access pattern; that symbol
is then consumed by `uplo_access` to generate the matching element access code.

`asym` is the name of the matrix variable the generated checks refer to.

A plain `StaticMatrix` needs no check, so `expr_gen` is called once with `:any`.
"""
gen_by_access(expr_gen, a::Type{<:StaticMatrix}, asym = :a) = expr_gen(:any)
# Lower-triangular wrapper: no runtime check is generated, the access pattern is fixed.
gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a) =
    expr_gen(:lower_triangular)
# Transpose wrapper: no runtime check is generated, the access pattern is `:transpose`.
gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, asym = :a) =
    expr_gen(:transpose)
# Adjoint wrapper: no runtime check is generated, the access pattern is `:adjoint`.
gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, asym = :a) =
    expr_gen(:adjoint)
76+ """
77+ gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
78+
79+ Similar to `gen_by_access` with only one type argument. The difference is that tests for both
80+ arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
81+ first for matrix `a` and the second for matrix `b`.
82+ """
5083function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
5184 return quote
5285 return $ (gen_by_access (b, :b ) do access_b
@@ -94,6 +127,20 @@ function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix
94127 end )
95128 end
96129end
# Two-matrix variant for a `Transpose`-wrapped left operand. The access pattern for `a`
# is fixed to `:transpose`; the pattern for `b` is supplied by the one-matrix
# `gen_by_access` applied to type `b`, whose generated checks refer to the variable `b`.
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type)
    inner = gen_by_access(access_b -> expr_gen(:transpose, access_b), b, :b)
    return quote
        return $inner
    end
end
# Two-matrix variant for an `Adjoint`-wrapped left operand. The access pattern for `a`
# is fixed to `:adjoint`; the pattern for `b` is supplied by the one-matrix
# `gen_by_access` applied to type `b`, whose generated checks refer to the variable `b`.
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type)
    inner = gen_by_access(access_b -> expr_gen(:adjoint, access_b), b, :b)
    return quote
        return $inner
    end
end
97144
98145"""
99146 mul_result_structure(a::Type, b::Type)
@@ -111,6 +158,14 @@ function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerT
111158 return LowerTriangular
112159end
113160
161+ """
162+ uplo_access(sa, asym, k, j, uplo)
163+
164+ Generate code for matrix element access, for a matrix of size `sa` locally referred to
165+ as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
166+ statically known for this function to work. `uplo` is the access pattern mode generated
167+ by the `gen_by_access` function.
168+ """
114169function uplo_access (sa, asym, k, j, uplo)
115170 if uplo == :any
116171 return :($ asym[$ (LinearIndices (sa)[k, j])])
@@ -150,6 +205,10 @@ function uplo_access(sa, asym, k, j, uplo)
150205 else
151206 return :(zero (T))
152207 end
elseif uplo == :transpose
    # Transposed access: swap the two indices when computing the linear index.
    return :($asym[$(LinearIndices(sa)[j, k])])
elseif uplo == :adjoint
    # The symbol must be spelled `:adjoint` to match what `gen_by_access` passes for
    # `Adjoint` wrappers; a misspelled symbol would never match and the function would
    # return `nothing` for that access pattern. Indices are swapped as for
    # `:transpose` and the element access is wrapped in `adjoint`.
    return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
153212 end
154213end
155214
@@ -216,23 +275,35 @@ end
216275 end
217276end
218277
# Map a static array type, or a `LinearAlgebra` wrapper type around one, to the
# corresponding non-static abstract type with the same element type `T` and number of
# dimensions `N`. `_mul` uses the result to build the argument-type tuple for `invoke`
# when it falls back to the generic `*`.
function _unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}}
    return AbstractArray{T,N}
end
for TWR in (Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular)
    @eval function _unstatic_array(
        ::Type{$TWR{T,TSA}}
    ) where {S, T, N, TSA<:StaticArray{S,T,N}}
        return $TWR{T,<:AbstractArray{T,N}}
    end
end
282+
# Select the matrix-matrix multiplication kernel at code-generation time.
#
# `sa` and `sb` are the static sizes carried by `Sa` and `Sb`; inside the generator `a`
# and `b` are the operand *types*. The generated body is a single inlined call to one of
# `mul_unrolled`, `mul_unrolled_chunks`, `mul_loop`, or the generic `*` reached through
# `invoke` on the non-static abstract signature.
@generated function _mul(
    Sa::Size{sa},
    Sb::Size{sb},
    a::StaticMatMulLike{<:Any, <:Any, Ta},
    b::StaticMatMulLike{<:Any, <:Any, Tb},
) where {sa, sb, Ta, Tb}
    # True only when neither operand is wrapped (Symmetric, Adjoint, ...).
    both_plain = a <: StaticMatrix && b <: StaticMatrix
    # Heuristic choice for amount of codegen
    few_flops = sa[1] * sa[2] * sb[2] <= 8 * 8 * 8

    kernel = if few_flops || !both_plain
        :mul_unrolled
    elseif both_plain && max(sa[1], sa[2], sb[2]) <= 14
        :mul_unrolled_chunks
    elseif both_plain
        :mul_loop
    else
        # NOTE(review): as the conditions are written, every wrapped operand is already
        # captured by the first branch, so this fallback appears unreachable — confirm
        # whether large wrapped matrices were meant to end up here instead.
        nothing
    end

    if kernel === nothing
        # we don't have any special code for handling this case so let's fall back to
        # the generic implementation of matrix multiplication
        return quote
            @_inline_meta
            return invoke(*, Tuple{$(_unstatic_array(a)),$(_unstatic_array(b))}, a, b)
        end
    end

    return quote
        @_inline_meta
        return $kernel(Sa, Sb, a, b)
    end
end
238309
0 commit comments