11"""
22Block FP8 Gemm benchmark
33============================
4-
This benchmark comes from SGLang kernels.
https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
-
"""

-import functools
-import json
-import logging
-import os
-from typing import Any, Dict, List, Optional
+from typing import List

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

-logger = logging.getLogger(__name__)
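+# Device properties are queried once at import time; has_enough_memory() below uses them
+# to skip GEMM shapes that would not fit in device memory.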
+DEVICE_NAME = torch.xpu.get_device_name()
+DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


@triton.jit
@@ -107,42 +102,6 @@ def _w8a8_block_fp8_matmul(
    tl.store(c_ptrs, c, mask=c_mask)


-@functools.lru_cache
-def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Optional[Dict[int, Any]]:
-    """
-    Return optimized configurations for the w8a8 block fp8 kernel.
-
-    The return value will be a dictionary that maps an irregular grid of
-    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
-    kernel on a given batch size bs, the closest batch size in the grid should
-    be picked and the associated configuration chosen to invoke the kernel.
-    """
-
-    # First look up if an optimized configuration is available in the configs
-    # directory
-    device_name = torch.xpu.get_device_name(0).replace(" ", "_")
-    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
-
-    config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
-    if os.path.exists(config_file_path):
-        with open(config_file_path, "r", encoding="utf-8") as f:
-            logger.info(
-                "Using configuration from %s for W8A8 Block FP8 kernel.",
-                config_file_path,
-            )
-            # If a configuration has been found, return it
-            return {int(key): val for key, val in json.load(f).items()}
-
-    # If no optimized configuration is available, we will use the default
-    # configuration
-    logger.warning(
-        ("Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
-         "Config file not found at %s"),
-        config_file_path,
-    )
-    return None
-
-
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
@@ -152,18 +111,15 @@ def w8a8_block_fp8_matmul(
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.
-
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
-
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
-
    Returns:
        torch.Tensor: The result of matmul.
    """
@@ -183,22 +139,16 @@ def w8a8_block_fp8_matmul(
    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

-    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": block_size[0],
-            "BLOCK_SIZE_K": block_size[1],
-            "GROUP_SIZE_M": 32,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
+    # Default config
+    # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+    config = {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": block_size[0],
+        "BLOCK_SIZE_K": block_size[1],
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3,
+    }
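+    # With the benchmark's block_size = [128, 128] this gives BLOCK_SIZE_N = BLOCK_SIZE_K = 128,
+    # so the 1-D launch grid defined below is cdiv(M, 64) * cdiv(N, 128) programs.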

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
@@ -232,7 +182,7 @@ def grid(META):
    return C


-# Reference path
+# Reference implementation used for testing
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

@@ -284,55 +234,51 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
    return C


-X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
-    [1, 1, 13824, 5120],
-    [1, 4, 12288, 4096],
-    [1, 512, 8192, 8192],
-    [1, 512, 8192, 32768],
-    [1, 512, 32768, 8192],
-    [1, 1024, 8192, 16384],
-    [1, 1024, 8192, 28672],
-    [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
-    [1, 4096, 8192, 16384],
-    [1, 8192, 1024, 16384],
-    [1, 8192, 4096, 16384],
-    [1, 16384, 1024, 8192],
-    [1, 16384, 4096, 8192],
-    [1, 16384, 8192, 1024],
-    [1, 16384, 8192, 4096],
-    [4, 32768, 128, 4096],
-    [4, 32768, 4096, 128],
-    [32, 4096, 128, 4096],
-    [4096, 8, 128, 16384],
-    [4096, 8, 16384, 128],
-]
-
-DEVICE_NAME = torch.xpu.get_device_name()
-DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
-
-
-def is_enough_memory(x_val):
-    # x_val: (B, M, N, K)
-    B, M, N, K = x_val
-    # a: (B, M, K) float8_e4m3
-    # b: (B, N, K) float8_e4m3
-    # c: (B, M, N) bfloat16
-    # pytorch reference: (B, M, N) float32
-    required_memory = B * M * K * 1 + B * N * K * 1 + B * M * N * 2 * 2
+def has_enough_memory(x_val):
+    # x_val: (M, N, K)
+    M, N, K = x_val
+    # a: (M, K) float8_e4m3
+    # b: (N, K) float8_e4m3
+    # c: (M, N) bfloat16
+    # pytorch reference: (M, N) float32
+    required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2
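+    # e.g. (M, N, K) = (16384, 8192, 4096) needs 16384*4096 + 8192*4096 + 16384*8192*4 bytes = 608 MiB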
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
    if not enough_memory:
        print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
    return enough_memory


-X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]
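+# (M, N, K) GEMM shapes: square 1K/2K/4K/8K sizes from the comprehension plus assorted skewed shapes.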
+X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
+    [1, 13824, 5120],
+    [4, 12288, 4096],
+    [512, 8192, 8192],
+    [512, 8192, 32768],
+    [512, 32768, 8192],
+    [1024, 8192, 16384],
+    [1024, 8192, 28672],
+    [3072, 3072, 4096],
+    [4096, 8192, 16384],
+    [8192, 1024, 16384],
+    [8192, 4096, 16384],
+    [16384, 1024, 8192],
+    [16384, 4096, 8192],
+    [16384, 8192, 1024],
+    [16384, 8192, 4096],
+    [32768, 128, 4096],
+    [32768, 4096, 128],
+    [4096, 128, 4096],
+    [8, 128, 16384],
+    [8, 16384, 128],
+]
+
+X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
-        x_names=["B", "M", "N", "K"],
+        x_names=["M", "N", "K"],
        # different possible values for x_name
        x_vals=X_VALS,
        line_arg="provider",
@@ -342,16 +288,14 @@ def is_enough_memory(x_val):
        line_names=["Triton"],
        # line styles
        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
-        plot_name="matmul-performance",
+        plot_name="sglang-fp8-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
-def benchmark(B, M, N, K, provider):
-    assert provider == "triton"
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)

    block_size = [128, 128]
-
-    torch.manual_seed(0)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -371,15 +315,18 @@ def benchmark(B, M, N, K, provider):

    quantiles = [0.5, 0.0, 1.0]

-    triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-    torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-    rtol = 1e-2
-    atol = 3e-4
-    benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg="triton to torch")
-    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    if provider == "triton":
+        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f"Unsupported provider {provider}")

-    tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: B * ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)
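+    # 2*M*N*K FLOPs for the GEMM; bytes moved modeled as fp8 A (M*K) + fp8 B (K*N) + fp16 C (2*M*N).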
+    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
