@@ -17,6 +17,13 @@ mul_wrappers = [
1717 (m -> Transpose (m), " transpo" ),
1818 (m -> Diagonal (m), " diag " )]
1919
# Reduced table of (wrapper function, label) pairs: identity, upper-symmetric,
# upper-triangular, transpose and diagonal views of a matrix.
mul_wrappers_reduced = [
    (mat -> mat, " ident "),
    (mat -> Symmetric(mat, :U), " sym-u "),
    (mat -> UpperTriangular(mat), " up-tri "),
    (mat -> Transpose(mat), " transpo"),
    (mat -> Diagonal(mat), " diag "),
]
26+
2027for N in [2 , 4 , 8 , 10 , 16 ]
2128
2229 matvecstr = @sprintf (" mat-vec %2d" , N)
@@ -41,7 +48,7 @@ for N in [2, 4, 8, 10, 16]
4148 thrown = true
4249 end
4350 if ! thrown
44- suite[matvecstr][wrapper_name] = @benchmarkable $ (wrapper_a (A)) * $ bv
51+ suite[matvecstr][wrapper_name] = @benchmarkable $ (Ref ( wrapper_a (A)))[] * $ ( Ref (bv))[]
4552 end
4653 end
4754
@@ -53,7 +60,7 @@ for N in [2, 4, 8, 10, 16]
5360 thrown = true
5461 end
5562 if ! thrown
56- suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $ (wrapper_a (A)) * $ (wrapper_b (B))
63+ suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $ (Ref ( wrapper_a (A)))[] * $ (Ref ( wrapper_b (B)))[]
5764 end
5865 end
5966
@@ -68,7 +75,7 @@ for N in [2, 4, 8, 10, 16]
6875 thrown = true
6976 end
7077 if ! thrown
71- suite[matvec_mut_str][wrapper_name] = @benchmarkable mul! ($ cv, $ (wrapper_a (A)), $ bv )
78+ suite[matvec_mut_str][wrapper_name] = @benchmarkable mul! ($ cv, $ (Ref ( wrapper_a (A)))[] , $ ( Ref (bv))[] )
7279 end
7380 end
7481
@@ -80,7 +87,7 @@ for N in [2, 4, 8, 10, 16]
8087 thrown = true
8188 end
8289 if ! thrown
83- suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul! ($ C, $ (wrapper_a (A)), $ (wrapper_b (B)))
90+ suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul! ($ C, $ (Ref ( wrapper_a (A)))[] , $ (Ref ( wrapper_b (B)))[] )
8491 end
8592 end
8693end
@@ -111,3 +118,94 @@ function judge_results(m1, m2)
111118 end
112119 return results
113120end
121+
"""
    generic_mul(size_a, size_b, a, b)

Multiply `a` by `b` by explicitly invoking the `*` method whose signature is
built from `StaticArrays._unstatic_array(typeof(a))` and
`StaticArrays._unstatic_array(typeof(b))`, rather than the method that ordinary
dispatch on `a` and `b` would select.

`size_a` and `size_b` are accepted but not used by this function.
"""
function generic_mul(size_a, size_b, a, b)
    lhs_type = StaticArrays._unstatic_array(typeof(a))
    rhs_type = StaticArrays._unstatic_array(typeof(b))
    return invoke(*, Tuple{lhs_type,rhs_type}, a, b)
end
125+
"""
    full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64)

Build and run a benchmark suite of matrix-matrix products.

For every `(N, M, K)` drawn from `size_iter`, a random `SMatrix{N,M,T}` is
multiplied by a random `SMatrix{M,K,T}` using `generic_mul`,
`StaticArrays._mul`, `StaticArrays.mul_unrolled`, and, for some wrapper
combinations, `StaticArrays.mul_unrolled_chunks` and `StaticArrays.mul_loop`.

`mul_wrappers` is a collection of `(wrapper, name)` pairs. All wrappers are
applied to square operands; non-square operands only get the first entry.

Returns a vector of `(benchmark_name, median_time)` tuples.
"""
function full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64)
    suite_full = BenchmarkGroup()
    for N in size_iter
        for M in size_iter
            lhs = randn(SMatrix{N,M,T})
            # Only square operands are combined with every wrapper.
            lhs_wrappers = N == M ? mul_wrappers : [mul_wrappers[1]]
            lhs_size = Size(lhs)
            for K in size_iter
                rhs = randn(SMatrix{M,K,T})
                rhs_wrappers = M == K ? mul_wrappers : [mul_wrappers[1]]
                rhs_size = Size(rhs)
                for (w_a, w_a_name) in lhs_wrappers, (w_b, w_b_name) in rhs_wrappers
                    # Operands are interpolated through a `Ref` and dereferenced
                    # inside the benchmarked expression.
                    label = @sprintf(" mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                    suite_full[label] = @benchmarkable generic_mul($lhs_size, $rhs_size, $(Ref(w_a(lhs)))[], $(Ref(w_b(rhs)))[])

                    label = @sprintf(" mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                    suite_full[label] = @benchmarkable StaticArrays._mul($lhs_size, $rhs_size, $(Ref(w_a(lhs)))[], $(Ref(w_b(rhs)))[])

                    label = @sprintf(" mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                    suite_full[label] = @benchmarkable StaticArrays.mul_unrolled($lhs_size, $rhs_size, $(Ref(w_a(lhs)))[], $(Ref(w_b(rhs)))[])

                    # The chunked kernel is skipped whenever either operand is diagonal-wrapped.
                    if !(w_a_name == " diag " || w_b_name == " diag ")
                        label = @sprintf(" mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                        suite_full[label] = @benchmarkable StaticArrays.mul_unrolled_chunks($lhs_size, $rhs_size, $(Ref(w_a(lhs)))[], $(Ref(w_b(rhs)))[])
                    end

                    # The loop kernel is only benchmarked for two unwrapped operands.
                    if w_a_name == " ident " && w_b_name == " ident "
                        label = @sprintf(" mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                        suite_full[label] = @benchmarkable StaticArrays.mul_loop($lhs_size, $rhs_size, $(Ref(w_a(lhs)))[], $(Ref(w_b(rhs)))[])
                    end
                end
            end
        end
    end
    trials = run(suite_full, verbose = true)
    # Reduce every trial to its median time, keyed by benchmark name.
    return [(entry[1], median(entry[2]).time) for entry in collect(trials)]
end
164+
"""
    judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)

Print a two-line report when `new_time * tol < old_time`, i.e. when the
candidate timing beats the reference timing by more than the factor `tol`.
The first line names the wrapper pair, the `(N, M) x (M, K)` shape and the
candidate `which`; the second line shows both timings. Nothing is printed
otherwise. Always returns `nothing`.
"""
function judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)
    # Guard clause: stay silent unless the candidate is clearly faster.
    new_time * tol < old_time || return nothing
    println(@sprintf(" better for %s %s (%2d, %2d) x (%2d, %2d): %s", w_a_name, w_b_name, N, M, M, K, which))
    println(" >> ", new_time, " | ", old_time)
    return nothing
end
172+
"""
    pick_best(results, mul_wrappers, size_iter; tol = 1.2)

Compare every benchmarked multiplication variant against the "default" one and
report, via `judge_this`, each variant that is faster by more than the factor
`tol`.

# Arguments
- `results`: timings keyed by benchmark name. Either a collection that can be
  indexed by name, or a vector of `(name, time)` entries such as the one
  returned by `full_benchmark`; a vector is converted to a `Dict` before use.
- `mul_wrappers`: collection of `(wrapper, name)` pairs; only the names are
  used here. Square operands use every entry, non-square ones only the first.
- `size_iter`: iterable of sizes from which `N`, `M` and `K` are drawn.
- `tol`: speed-up factor forwarded to `judge_this`.

Throws the lookup error of `results` (e.g. `KeyError`) when an expected
benchmark name is missing. Returns `nothing`.
"""
function pick_best(results, mul_wrappers, size_iter; tol = 1.2)
    # A Vector of (name, time) entries cannot be indexed by name, so build a
    # lookup table; anything else is assumed to already support `results[name]`.
    times = results isa AbstractVector ? Dict(results) : results
    for N in size_iter, M in size_iter
        wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]]
        for K in size_iter
            wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]]
            for (_, w_a_name) in wrappers_a, (_, w_b_name) in wrappers_b
                cur_default = @sprintf(" mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                default_time = times[cur_default]

                cur_generic = @sprintf(" mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                generic_time = times[cur_generic]
                judge_this(generic_time, default_time, tol, w_a_name, w_b_name, N, M, K, " generic")

                cur_unrolled = @sprintf(" mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                unrolled_time = times[cur_unrolled]
                judge_this(unrolled_time, default_time, tol, w_a_name, w_b_name, N, M, K, " unrolled")

                # "chunks" timings exist only when neither operand is diagonal-wrapped.
                if w_a_name != " diag " && w_b_name != " diag "
                    cur_chunks = @sprintf(" mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                    chunk_time = times[cur_chunks]
                    judge_this(chunk_time, default_time, tol, w_a_name, w_b_name, N, M, K, " chunks")
                end
                # "loop" timings exist only for two unwrapped operands.
                if w_a_name == " ident " && w_b_name == " ident "
                    cur_loop = @sprintf(" mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
                    loop_time = times[cur_loop]
                    judge_this(loop_time, default_time, tol, w_a_name, w_b_name, N, M, K, " loop")
                end
            end
        end
    end
    return nothing
end
208+
"""
    run_1()

Run `full_benchmark` on `mul_wrappers_reduced` for the sizes
2, 3, 4, 5, 8, 9, 14 and 16, and return its result.
"""
function run_1()
    sizes = [2, 3, 4, 5, 8, 9, 14, 16]
    return full_benchmark(mul_wrappers_reduced, sizes)
end
0 commit comments