# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import time

# ... (rest of the multi-line import block elided in this excerpt) ...

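# NOTE: judging by the usage below, the elided imports are assumed to provide
# torch, ops (vLLM's custom-op bindings), per_token_group_quant_fp8,
# get_col_major_tma_aligned_tensor, fp8_gemm_nt, w8a8_triton_block_scaled_mm,
# calc_diff, and the block_size constant; the exact module paths are not shown
# in this excerpt.
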
def benchmark_shape(
    m: int,
    n: int,
    k: int,
    warmup: int = 100,
    repeat: int = 10000,
    verbose: bool = False,
) -> dict:
    """Benchmark all implementations for a specific (m, n, k) shape."""
    if verbose:
        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

    # Create test tensors
    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

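    # Layout note (inferred from usage below): A holds (m, k) activations while
    # B is stored pre-transposed as (n, k), matching the "NT" convention of
    # fp8_gemm_nt; the CUTLASS path instead passes an explicit (k, n) view via
    # B_vllm.T.
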
    # Reference result in BF16
    torch.cuda.synchronize()
    # ... (reference GEMM and weight-side (B) quantization elided) ...

    # Pre-quantize A for all implementations
    A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
    A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        A, block_size[1], column_major_scales=True
    )

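    # All three paths share the same per-token-group FP8 quantization of A
    # (group size block_size[1]); they differ only in the scale layout each
    # kernel expects: DeepGEMM wants column-major, TMA-aligned scales, and the
    # CUTLASS path requests column-major scales at quantization time.
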
    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        fp8_gemm_nt(
            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
        )
        return C_deepgemm
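    # Unlike the two wrappers below, fp8_gemm_nt writes into the preallocated
    # C_deepgemm buffer instead of allocating its output, so output allocation
    # stays out of DeepGEMM's timed region.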

    # === vLLM Triton Implementation ===
    def vllm_triton_gemm():
        return w8a8_triton_block_scaled_mm(
            A_vllm,
            B_vllm,
            A_scale_vllm,
            B_scale_vllm,
            block_size,
            output_dtype=torch.bfloat16,
        )

    # === vLLM CUTLASS Implementation ===
    def vllm_cutlass_gemm():
        return ops.cutlass_scaled_mm(
            A_vllm_cutlass,
            B_vllm.T,
            scale_a=A_scale_vllm_cutlass,
            scale_b=B_scale_vllm.T,
            out_dtype=torch.bfloat16,
        )

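    # Note: B_vllm is stored (n, k), so B_vllm.T presents the (k, n) operand
    # layout that ops.cutlass_scaled_mm expects; the weight block-scales are
    # transposed to match.
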
    # Run correctness check first
    if verbose:
        # ... (elided: each kernel is run once and compared against the BF16
        # reference via calc_diff) ...
        print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
        print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
        print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
        print(
            "vLLM Triton vs DeepGEMM difference: "
            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
        )
        print(
            "vLLM CUTLASS vs DeepGEMM difference: "
            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
        )

    # Benchmark implementations
    implementations = {
        "DeepGEMM": deepgemm_gemm,
        "vLLM Triton": vllm_triton_gemm,
        "vLLM CUTLASS": vllm_cutlass_gemm,
    }

    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}

    for name, func in implementations.items():
        # Warmup
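        # The elided code below times func over `repeat` iterations after
        # `warmup` untimed calls, synchronizing CUDA around the timed region,
        # and derives the metrics used below. For an (m, n, k) GEMM the usual
        # conventions are (a sketch, not the elided code verbatim):
        #   tflops = 2 * m * n * k / (avg_time_s * 1e12)  # 2 FLOPs per MAC
        #   gb_s   = total bytes moved (FP8 A and B, BF16 C, scales)
        #            / avg_time_s / 1e9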
        # ... (warmup and timing loop elided) ...
        benchmark_results["implementations"][name] = {
            # ... ("time_ms" / "time_us" entries elided; both are referenced
            # later) ...
            "tflops": tflops,
            "gb_s": gb_s,
            "diff": {
                "DeepGEMM": 0.0
                if name == "DeepGEMM"
                else calc_diff(func(), C_deepgemm),
                "Reference": deepgemm_diff
                if name == "DeepGEMM"
                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
            },
        }
152155 if verbose :
153- print (
154- f"{ name } : { avg_time_ms :.3f} ms, { tflops :.2f} TFLOPS, { gb_s :.2f} GB/s"
155- )
156+ print (f"{ name } : { avg_time_ms :.3f} ms, { tflops :.2f} TFLOPS, { gb_s :.2f} GB/s" )
156157
    # Calculate speedups
    baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
    for name, data in benchmark_results["implementations"].items():
        if name != "DeepGEMM":
            speedup = baseline / data["time_ms"]
            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
            if verbose:
                print(
                    f"DeepGEMM is {1 / speedup:.2f}x "
                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
                )

    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
        cutlass_vs_triton
    )
    if verbose:
        print(
            f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
            f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
        )

    return benchmark_results


def format_table_row(values, widths):
    """Format a row with specified column widths."""
    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"


def print_table(headers, rows, title=None):
    # ... (table-printing body elided) ...


# ... (remaining helpers and the start of run_benchmarks(verbose: bool = False)
# elided; the loops below live inside run_benchmarks, where all_results holds
# one benchmark_shape() dict per shape and deepgemm_headers / deepgemm_rows
# were initialized just above) ...
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["DeepGEMM"]
        deepgemm_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
            ]
        )

    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")

    # Print vLLM Triton table
    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
    triton_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM Triton"]
        speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
        triton_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
                format_speedup(speedup),
            ]
        )

    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")

    # Print vLLM CUTLASS table
    cutlass_headers = [
        "m",
        "n",
        "k",
        "Time (μs)",
        "TFLOPS",
        "GB/s",
        "vs DeepGEMM",
        "vs Triton",
    ]
    cutlass_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM CUTLASS"]
        vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
        vs_triton = impl_data.get("speedup_vs_triton", 1.0)
        cutlass_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
                format_speedup(vs_deepgemm),
                format_speedup(vs_triton),
            ]
        )

    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")

    # Calculate and print averages
    print("\n===== AVERAGE PERFORMANCE =====")

    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
    avg_metrics = {
        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
    }

    for result in all_results:
        # ... (elided: per-shape accumulation into avg_metrics, num_shapes,
        # avg_headers / avg_rows setup, and the "for impl in implementations:"
        # loop header) ...
        avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
        avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
        avg_time = avg_metrics[impl]["time_ms"] / num_shapes
        avg_rows.append(
            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
        )

    print_table(avg_headers, avg_rows)

    # Calculate average speedups
    avg_speedups = {
        "DeepGEMM vs vLLM Triton": 0,
        "DeepGEMM vs vLLM CUTLASS": 0,
        "vLLM CUTLASS vs vLLM Triton": 0,
    }

    for result in all_results:
        deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
        vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]

        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
            vllm_triton_time / vllm_cutlass_time
        )

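    # Each accumulated ratio is second-named time / first-named time, so a
    # value above 1.0 means the first-named implementation is faster for that
    # shape; the totals are presumably divided by the shape count before being
    # tabulated in the (elided) rows below.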
    print("\n===== AVERAGE SPEEDUPS =====")
    speedup_headers = ["Comparison", "Speedup"]
    # ... (speedup-table rows and printing, plus avg_diff initialization,
    # elided) ...

    for result in all_results:
        for impl in implementations:
            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
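    # avg_diff (initialized in the elided code above) accumulates each
    # implementation's calc_diff score against the BF16 reference, one entry
    # per shape, mirroring how the speedup averages are built.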

    diff_headers = ["Implementation", "Avg Diff vs Reference"]
    diff_rows = []