@@ -15,150 +15,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1515@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Adjoint{<:Any,<:StaticVector} ) where {N} = vec (A) * B
1616@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
1717
18- """
19- gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a)
20-
21- Statically generate outer code for fully unrolled multiplication loops.
22- Returned code does wrapper-specific tests (for example if a symmetric matrix view is
23- `U` or `L`) and the body of the if expression is then generated by function `expr_gen`.
24- The function `expr_gen` receives access pattern description symbol as its argument
25- and this symbol is then consumed by uplo_access to generate the right code for matrix
26- element access.
27-
28- The name of the matrix to test is indicated by `asym`.
29- """
30- function gen_by_access (expr_gen, a:: Type{<:StaticVecOrMat} , asym = :wrapped_a )
31- return expr_gen (:any )
32- end
33- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
34- return quote
35- if $ (asym). uplo == ' U'
36- $ (expr_gen (:up ))
37- else
38- $ (expr_gen (:lo ))
39- end
40- end
41- end
42- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
43- return quote
44- if $ (asym). uplo == ' U'
45- $ (expr_gen (:up_herm ))
46- else
47- $ (expr_gen (:lo_herm ))
48- end
49- end
50- end
51- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
52- return expr_gen (:upper_triangular )
53- end
54- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
55- return expr_gen (:lower_triangular )
56- end
57- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
58- return expr_gen (:unit_upper_triangular )
59- end
60- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , asym = :wrapped_a )
61- return expr_gen (:unit_lower_triangular )
62- end
63- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
64- return expr_gen (:transpose )
65- end
66- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticVecOrMat}} , asym = :wrapped_a )
67- return expr_gen (:adjoint )
68- end
69- function gen_by_access (expr_gen, a:: Type{<:SDiagonal} , asym = :wrapped_a )
70- return expr_gen (:diagonal )
71- end
72- """
73- gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
74-
75- Simiar to gen_by_access with only one type argument. The difference is that tests for both
76- arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
77- first for matrix `a` and the second for matrix `b`.
78- """
79- function gen_by_access (expr_gen, a:: Type{<:StaticMatrix} , b:: Type )
80- return quote
81- return $ (gen_by_access (b, :wrapped_b ) do access_b
82- expr_gen (:any , access_b)
83- end )
84- end
85- end
86- function gen_by_access (expr_gen, a:: Type{<:Symmetric{<:Any, <:StaticMatrix}} , b:: Type )
87- return quote
88- if wrapped_a. uplo == ' U'
89- return $ (gen_by_access (b, :wrapped_b ) do access_b
90- expr_gen (:up , access_b)
91- end )
92- else
93- return $ (gen_by_access (b, :wrapped_b ) do access_b
94- expr_gen (:lo , access_b)
95- end )
96- end
97- end
98- end
99- function gen_by_access (expr_gen, a:: Type{<:Hermitian{<:Any, <:StaticMatrix}} , b:: Type )
100- return quote
101- if wrapped_a. uplo == ' U'
102- return $ (gen_by_access (b, :wrapped_b ) do access_b
103- expr_gen (:up_herm , access_b)
104- end )
105- else
106- return $ (gen_by_access (b, :wrapped_b ) do access_b
107- expr_gen (:lo_herm , access_b)
108- end )
109- end
110- end
111- end
112- function gen_by_access (expr_gen, a:: Type{<:UpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
113- return quote
114- return $ (gen_by_access (b, :wrapped_b ) do access_b
115- expr_gen (:upper_triangular , access_b)
116- end )
117- end
118- end
119- function gen_by_access (expr_gen, a:: Type{<:LowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
120- return quote
121- return $ (gen_by_access (b, :wrapped_b ) do access_b
122- expr_gen (:lower_triangular , access_b)
123- end )
124- end
125- end
126- function gen_by_access (expr_gen, a:: Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}} , b:: Type )
127- return quote
128- return $ (gen_by_access (b, :wrapped_b ) do access_b
129- expr_gen (:unit_upper_triangular , access_b)
130- end )
131- end
132- end
133- function gen_by_access (expr_gen, a:: Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}} , b:: Type )
134- return quote
135- return $ (gen_by_access (b, :wrapped_b ) do access_b
136- expr_gen (:unit_lower_triangular , access_b)
137- end )
138- end
139- end
140- function gen_by_access (expr_gen, a:: Type{<:Transpose{<:Any, <:StaticMatrix}} , b:: Type )
141- return quote
142- return $ (gen_by_access (b, :wrapped_b ) do access_b
143- expr_gen (:transpose , access_b)
144- end )
145- end
146- end
147- function gen_by_access (expr_gen, a:: Type{<:Adjoint{<:Any, <:StaticMatrix}} , b:: Type )
148- return quote
149- return $ (gen_by_access (b, :wrapped_b ) do access_b
150- expr_gen (:adjoint , access_b)
151- end )
152- end
153- end
154- function gen_by_access (expr_gen, a:: Type{<:SDiagonal} , b:: Type )
155- return quote
156- return $ (gen_by_access (b, :wrapped_b ) do access_b
157- expr_gen (:diagonal , access_b)
158- end )
159- end
160- end
161-
16218"""
16319 mul_result_structure(a::Type, b::Type)
16420
@@ -202,99 +58,6 @@ function mul_result_structure(::SDiagonal, ::SDiagonal)
20258 return Diagonal
20359end
20460
205- """
206- uplo_access(sa, asym, k, j, uplo)
207-
208- Generate code for matrix element access, for a matrix of size `sa` locally referred to
209- as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
210- statically known for this function to work. `uplo` is the access pattern mode generated
211- by the `gen_by_access` function.
212- """
213- function uplo_access (sa, asym, k, j, uplo)
214- TAsym = Symbol (" T" * string (asym))
215- if uplo == :any
216- return :($ asym[$ (LinearIndices (sa)[k, j])])
217- elseif uplo == :up
218- if k < j
219- return :($ asym[$ (LinearIndices (sa)[k, j])])
220- elseif k == j
221- return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
222- else
223- return :(transpose ($ asym[$ (LinearIndices (sa)[j, k])]))
224- end
225- elseif uplo == :lo
226- if k > j
227- return :($ asym[$ (LinearIndices (sa)[k, j])])
228- elseif k == j
229- return :(LinearAlgebra. symmetric ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
230- else
231- return :(transpose ($ asym[$ (LinearIndices (sa)[j, k])]))
232- end
233- elseif uplo == :up_herm
234- if k < j
235- return :($ asym[$ (LinearIndices (sa)[k, j])])
236- elseif k == j
237- return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :U ))
238- else
239- return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
240- end
241- elseif uplo == :lo_herm
242- if k > j
243- return :($ asym[$ (LinearIndices (sa)[k, j])])
244- elseif k == j
245- return :(LinearAlgebra. hermitian ($ asym[$ (LinearIndices (sa)[k, j])], :L ))
246- else
247- return :(adjoint ($ asym[$ (LinearIndices (sa)[j, k])]))
248- end
249- elseif uplo == :upper_triangular
250- if k <= j
251- return :($ asym[$ (LinearIndices (sa)[k, j])])
252- else
253- return :(zero ($ TAsym))
254- end
255- elseif uplo == :lower_triangular
256- if k >= j
257- return :($ asym[$ (LinearIndices (sa)[k, j])])
258- else
259- return :(zero ($ TAsym))
260- end
261- elseif uplo == :unit_upper_triangular
262- if k < j
263- return :($ asym[$ (LinearIndices (sa)[k, j])])
264- elseif k == j
265- return :(oneunit ($ TAsym))
266- else
267- return :(zero ($ TAsym))
268- end
269- elseif uplo == :unit_lower_triangular
270- if k > j
271- return :($ asym[$ (LinearIndices (sa)[k, j])])
272- elseif k == j
273- return :(oneunit ($ TAsym))
274- else
275- return :(zero ($ TAsym))
276- end
277- elseif uplo == :upper_hessenberg
278- if k <= j+ 1
279- return :($ asym[$ (LinearIndices (sa)[k, j])])
280- else
281- return :(zero ($ TAsym))
282- end
283- elseif uplo == :transpose
284- return :(transpose ($ asym[$ (LinearIndices (reverse (sa))[j, k])]))
285- elseif uplo == :adjoint
286- return :(adjoint ($ asym[$ (LinearIndices (reverse (sa))[j, k])]))
287- elseif uplo == :diagonal
288- if k == j
289- return :($ asym[$ k])
290- else
291- return :(zero ($ TAsym))
292- end
293- else
294- error (" Unknown uplo: $uplo " )
295- end
296- end
297-
29861# Implementations
29962
30063function mul_smat_vec_exprs (sa, access_a)
@@ -369,31 +132,6 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria
369132 @eval _unstatic_array (:: Type{$TWR{T,TSA}} ) where {S, T, N, TSA<: StaticArray{S,T,N} } = $ TWR{T,<: AbstractArray{T,N} }
370133end
371134
372- function combine_products (expr_list)
373- filtered = filter (expr_list) do expr
374- if expr. head != :call || expr. args[1 ] != :*
375- error (" expected call to *" )
376- end
377- for arg in expr. args[2 : end ]
378- if isa (arg, Expr) && arg. head == :call && arg. args[1 ] == :zero
379- return false
380- end
381- end
382- return true
383- end
384- if isempty (filtered)
385- return :(zero (T))
386- else
387- return reduce (filtered) do ex1, ex2
388- if ex2. head != :call || ex2. args[1 ] != :*
389- error (" expected call to *" )
390- end
391-
392- return :(muladd ($ (ex2. args[2 ]), $ (ex2. args[3 ]), $ ex1))
393- end
394- end
395- end
396-
397135@generated function _mul (Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatMulLike{<:Any, <:Any, Ta} , b:: StaticMatMulLike{<:Any, <:Any, Tb} ) where {sa, sb, Ta, Tb}
398136 S = Size (sa[1 ], sb[2 ])
399137 # Heuristic choice for amount of codegen
0 commit comments