@@ -4,26 +4,10 @@ import LinearAlgebra: BlasFloat, matprod, mul!
44# Manage dispatch of * and mul!
55# TODO Adjoint? (Inner product?)
66
7- """
8- StaticMatMulLike
9-
10- Static wrappers used for multiplication dispatch.
11- """
12- const StaticMatMulLike{s1, s2, T} = Union{
13- StaticMatrix{s1, s2, T},
14- Symmetric{T, <: StaticMatrix{s1, s2, T} },
15- Hermitian{T, <: StaticMatrix{s1, s2, T} },
16- LowerTriangular{T, <: StaticMatrix{s1, s2, T} },
17- UpperTriangular{T, <: StaticMatrix{s1, s2, T} },
18- UnitLowerTriangular{T, <: StaticMatrix{s1, s2, T} },
19- UnitUpperTriangular{T, <: StaticMatrix{s1, s2, T} },
20- UpperHessenberg{T, <: StaticMatrix{s1, s2, T} },
21- Adjoint{T, <: StaticMatrix{s1, s2, T} },
22- Transpose{T, <: StaticMatrix{s1, s2, T} }}
23-
24-
25- @inline * (A:: StaticMatMulLike , B:: AbstractVector ) = _mul (Size (A), A, B)
7+ # *(A::StaticMatMulLike, B::AbstractVector) causes an ambiguity with SparseArrays
8+ @inline * (A:: StaticMatrix , B:: AbstractVector ) = _mul (Size (A), A, B)
269@inline * (A:: StaticMatMulLike , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
10+ @inline * (A:: StaticMatrix , B:: StaticVector ) = _mul (Size (A), Size (B), A, B)
2711@inline * (A:: StaticMatMulLike , B:: StaticMatMulLike ) = _mul (Size (A), Size (B), A, B)
2812@inline * (A:: StaticVector , B:: StaticMatMulLike ) = * (reshape (A, Size (Size (A)[1 ], 1 )), B)
2913@inline * (A:: StaticVector , B:: Transpose{<:Any, <:StaticVector} ) = _mul (Size (A), Size (B), A, B)
@@ -32,7 +16,7 @@ const StaticMatMulLike{s1, s2, T} = Union{
3216@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
3317
3418"""
35- gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a )
19+ gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a )
3620
3721Statically generate outer code for fully unrolled multiplication loops.
3822Returned code does wrapper-specific tests (for example if a symmetric matrix view is
@@ -43,10 +27,10 @@ element access.
4327
4428The name of the matrix to test is indicated by `asym`.
4529"""
46- function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :a )
30+ function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :wrapped_a )
4731 return expr_gen (:any )
4832end
49- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :a )
33+ function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
5034 return quote
5135 if $ (asym). uplo == ' U'
5236 $ (expr_gen (:up ))
@@ -55,7 +39,7 @@ function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, as
5539 end
5640 end
5741end
58- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :a )
42+ function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
5943 return quote
6044 if $ (asym). uplo == ' U'
6145 $ (expr_gen (:up_herm ))
@@ -64,25 +48,22 @@ function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, as
6448 end
6549 end
6650end
67- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :a )
51+ function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
6852 return expr_gen (:upper_triangular )
6953end
70- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :a )
54+ function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
7155 return expr_gen (:lower_triangular )
7256end
73- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :a )
57+ function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
7458 return expr_gen (:unit_upper_triangular )
7559end
76- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :a )
60+ function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
7761 return expr_gen (:unit_lower_triangular )
7862end
79- function gen_by_access (expr_gen, a:: Type{<:UpperHessenberg{<:Any, <:StaticMatrix}} , asym = :a )
80- return expr_gen (:upper_hessenberg )
81- end
82- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :a )
63+ function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
8364 return expr_gen (:transpose )
8465end
85- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :a )
66+ function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
8667 return expr_gen (:adjoint )
8768end
8869"""
@@ -94,82 +75,75 @@ first for matrix `a` and the second for matrix `b`.
9475"""
9576function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
9677 return quote
97- return $ (gen_by_access (b, :b ) do access_b
78+ return $ (gen_by_access (b, :wrapped_b ) do access_b
9879 expr_gen (:any , access_b)
9980 end )
10081 end
10182end
10283function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , b:: Type )
10384 return quote
104- if a . uplo == ' U'
105- return $ (gen_by_access (b, :b ) do access_b
85+ if wrapped_a . uplo == ' U'
86+ return $ (gen_by_access (b, :wrapped_b ) do access_b
10687 expr_gen (:up , access_b)
10788 end )
10889 else
109- return $ (gen_by_access (b, :b ) do access_b
90+ return $ (gen_by_access (b, :wrapped_b ) do access_b
11091 expr_gen (:lo , access_b)
11192 end )
11293 end
11394 end
11495end
11596function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , b:: Type )
11697 return quote
117- if a . uplo == ' U'
118- return $ (gen_by_access (b, :b ) do access_b
98+ if wrapped_a . uplo == ' U'
99+ return $ (gen_by_access (b, :wrapped_b ) do access_b
119100 expr_gen (:up_herm , access_b)
120101 end )
121102 else
122- return $ (gen_by_access (b, :b ) do access_b
103+ return $ (gen_by_access (b, :wrapped_b ) do access_b
123104 expr_gen (:lo_herm , access_b)
124105 end )
125106 end
126107 end
127108end
128109function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
129110 return quote
130- return $ (gen_by_access (b, :b ) do access_b
111+ return $ (gen_by_access (b, :wrapped_b ) do access_b
131112 expr_gen (:upper_triangular , access_b)
132113 end )
133114 end
134115end
135116function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
136117 return quote
137- return $ (gen_by_access (b, :b ) do access_b
118+ return $ (gen_by_access (b, :wrapped_b ) do access_b
138119 expr_gen (:lower_triangular , access_b)
139120 end )
140121 end
141122end
142123function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
143124 return quote
144- return $ (gen_by_access (b, :b ) do access_b
125+ return $ (gen_by_access (b, :wrapped_b ) do access_b
145126 expr_gen (:unit_upper_triangular , access_b)
146127 end )
147128 end
148129end
149130function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
150131 return quote
151- return $ (gen_by_access (b, :b ) do access_b
132+ return $ (gen_by_access (b, :wrapped_b ) do access_b
152133 expr_gen (:unit_lower_triangular , access_b)
153134 end )
154135 end
155136end
156- function gen_by_access (expr_gen, a:: Type{<:UpperHessenberg{<:Any, <:StaticMatrix}} , b:: Type )
157- return quote
158- return $ (gen_by_access (b, :b ) do access_b
159- expr_gen (:upper_hessenberg , access_b)
160- end )
161- end
162- end
163137function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , b:: Type )
164138 return quote
165- return $ (gen_by_access (b, :b ) do access_b
139+ return $ (gen_by_access (b, :wrapped_b ) do access_b
166140 expr_gen (:transpose , access_b)
167141 end )
168142 end
169143end
170144function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , b:: Type )
171145 return quote
172- return $ (gen_by_access (b, :b ) do access_b
146+ return $ (gen_by_access (b, :wrapped_b ) do access_b
173147 expr_gen (:adjoint , access_b)
174148 end )
175149 end
@@ -200,65 +174,74 @@ statically known for this function to work. `uplo` is the access pattern mode ge
200174by the `gen_by_access` function.
201175"""
202176function uplo_access (sa, asym, k, j, uplo)
177+ TAsym = Symbol (" T" * string (asym))
203178 if uplo == :any
204179 return :($ asym[$ (LinearIndices (sa)[k, j])])
205180 elseif uplo == :up
206- if k <= j
181+ if k < j
207182 return :($ asym[$ (LinearIndices (sa)[k, j])])
183+ elseif k == j
184+ return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
208185 else
209- return :($ asym[$ (LinearIndices (sa)[j, k])])
186+ return :(transpose ( $ asym[$ (LinearIndices (sa)[j, k])]) )
210187 end
211188 elseif uplo == :lo
212- if k >= j
189+ if k > j
213190 return :($ asym[$ (LinearIndices (sa)[k, j])])
191+ elseif k == j
192+ return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
214193 else
215- return :($ asym[$ (LinearIndices (sa)[j, k])])
194+ return :(transpose ( $ asym[$ (LinearIndices (sa)[j, k])]) )
216195 end
217196 elseif uplo == :up_herm
218- if k <= j
197+ if k < j
219198 return :($ asym[$ (LinearIndices (sa)[k, j])])
199+ elseif k == j
200+ return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
220201 else
221202 return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
222203 end
223204 elseif uplo == :lo_herm
224- if k >= j
205+ if k > j
225206 return :($ asym[$ (LinearIndices (sa)[k, j])])
207+ elseif k == j
208+ return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
226209 else
227210 return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
228211 end
229212 elseif uplo == :upper_triangular
230213 if k <= j
231214 return :($ asym[$ (LinearIndices (sa)[k, j])])
232215 else
233- return :(zero (T ))
216+ return :(zero ($ TAsym ))
234217 end
235218 elseif uplo == :lower_triangular
236219 if k >= j
237220 return :($ asym[$ (LinearIndices (sa)[k, j])])
238221 else
239- return :(zero (T ))
222+ return :(zero ($ TAsym ))
240223 end
241224 elseif uplo == :unit_upper_triangular
242225 if k < j
243226 return :($ asym[$ (LinearIndices (sa)[k, j])])
244227 elseif k == j
245- return :(oneunit (T ))
228+ return :(oneunit ($ TAsym ))
246229 else
247- return :(zero (T ))
230+ return :(zero ($ TAsym ))
248231 end
249232 elseif uplo == :unit_lower_triangular
250233 if k > j
251234 return :($ asym[$ (LinearIndices (sa)[k, j])])
252235 elseif k == j
253- return :(oneunit (T ))
236+ return :(oneunit ($ TAsym ))
254237 else
255- return :(zero (T ))
238+ return :(zero ($ TAsym ))
256239 end
257240 elseif uplo == :upper_hessenberg
258241 if k <= j+ 1
259242 return :($ asym[$ (LinearIndices (sa)[k, j])])
260243 else
261- return :(zero (T ))
244+ return :(zero ($ TAsym ))
262245 end
263246 elseif uplo == :transpose
264247 return :($ asym[$ (LinearIndices (reverse (sa))[j, k])])
@@ -273,9 +256,9 @@ function mul_smat_vec_exprs(sa, access_a)
273256 return [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:($ (uplo_access (sa, :a , k, j, access_a))* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
274257end
275258
276- @generated function _mul (:: Size{sa} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
259+ @generated function _mul (:: Size{sa} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: AbstractVector{Tb} ) where {sa, Ta, Tb}
277260 if sa[2 ] != 0
278- retexpr = gen_by_access (a ) do access_a
261+ retexpr = gen_by_access (wrapped_a ) do access_a
279262 exprs = mul_smat_vec_exprs (sa, access_a)
280263 return :(@inbounds return similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
281264 end
@@ -290,17 +273,18 @@ end
290273 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $(size (b)) " ))
291274 end
292275 T = promote_op (matprod,Ta,Tb)
276+ a = mul_parent (wrapped_a)
293277 $ retexpr
294278 end
295279end
296280
297- @generated function _mul (:: Size{sa} , :: Size{sb} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
281+ @generated function _mul (:: Size{sa} , :: Size{sb} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticVector{<:Any, Tb} ) where {sa, sb, Ta, Tb}
298282 if sb[1 ] != sa[2 ]
299283 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
300284 end
301285
302286 if sa[2 ] != 0
303- retexpr = gen_by_access (a ) do access_a
287+ retexpr = gen_by_access (wrapped_a ) do access_a
304288 exprs = mul_smat_vec_exprs (sa, access_a)
305289 return :(@inbounds similar_type (b, T, Size (sa[1 ]))(tuple ($ (exprs... ))))
306290 end
312296 return quote
313297 @_inline_meta
314298 T = promote_op (matprod,Ta,Tb)
299+ a = mul_parent (wrapped_a)
315300 $ retexpr
316301 end
317302end
@@ -362,28 +347,30 @@ end
362347 end
363348end
364349
365- @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , a :: StaticMatMulLike{<:Any, <:Any, Ta} , b :: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
350+ @generated function mul_unrolled (:: Size{sa} , :: Size{sb} , wrapped_a :: StaticMatMulLike{<:Any, <:Any, Ta} , wrapped_b :: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
366351 if sb[1 ] != sa[2 ]
367352 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
368353 end
369354
370355 S = Size (sa[1 ], sb[2 ])
371356
372357 if sa[2 ] != 0
373- retexpr = gen_by_access (a, b ) do access_a, access_b
358+ retexpr = gen_by_access (wrapped_a, wrapped_b ) do access_a, access_b
374359 exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)),
375360 [:($ (uplo_access (sa, :a , k1, j, access_a))* $ (uplo_access (sb, :b , j, k2, access_b))) for j = 1 : sa[2 ]]
376361 ) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
377- return :((mul_result_structure (a, b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
362+ return :((mul_result_structure (wrapped_a, wrapped_b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
378363 end
379364 else
380365 exprs = [:(zero (T)) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
381- retexpr = :(return (mul_result_structure (a, b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
366+ retexpr = :(return (mul_result_structure (wrapped_a, wrapped_b ))(similar_type (a, T, $ S)(tuple ($ (exprs... )))))
382367 end
383368
384369 return quote
385370 @_inline_meta
386371 T = promote_op (matprod,Ta,Tb)
372+ a = mul_parent (wrapped_a)
373+ b = mul_parent (wrapped_b)
387374 @inbounds $ retexpr
388375 end
389376end
0 commit comments