
Commit c14abd5

airMeng authored and leonling-ll committed
add sglang block fp8 gemm into benchmark
fix bugs (rtol, atol); move fp8 gemm to sglang benchmark
1 parent 935cef5 commit c14abd5

File tree

3 files changed, +347 -14 lines changed

.github/workflows/third-party-benchmarks.yml
benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py
benchmarks/third_party/sglang/decode_attention_benchmark.py


.github/workflows/third-party-benchmarks.yml

Lines changed: 10 additions & 1 deletion
@@ -111,7 +111,7 @@ jobs:
       - name: Install SGLANG
         run: |
           git clone https://github.com/sgl-project/sglang.git
-          pip install sglang/python[srt_xpu]
+          pip install sglang/python[dev_xpu]

       - name: Run SGLANG attention prefill stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -140,6 +140,15 @@ jobs:
           source ../../../scripts/capture-hw-details.sh
           python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

+      - name: Run SGLANG Block FP8 GEMM benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}
+        run: |
+          cd benchmarks/third_party/sglang
+          python block_fp8_gemm_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../../scripts/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Upload benchmark reports
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         uses: actions/upload-artifact@v4

benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py

Lines changed: 335 additions & 0 deletions

"""
Block FP8 GEMM benchmark
============================
This benchmark comes from the SGLang kernels:
https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
"""

from typing import List

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

DEVICE_NAME = torch.xpu.get_device_name()
DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.
    """

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul.
    """

    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    # Default config
    # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3,
    }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    kernel = _w8a8_block_fp8_matmul

    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C


# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """

    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [[
        B[
            j * block_n:min((j + 1) * block_n, N),
            i * block_k:min((i + 1) * block_k, K),
        ] for i in range(k_tiles)
    ] for j in range(n_tiles)]
    C_tiles = [C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def has_enough_memory(x_val):
    # x_val: (M, N, K)
    M, N, K = x_val
    # a: (M, K) float8_e4m3
    # b: (N, K) float8_e4m3
    # c: (M, N) bfloat16
    # pytorch reference: (M, N) float32
    required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
    if not enough_memory:
        print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
    return enough_memory
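
# Worked example of the estimate above: for (M, N, K) = (8192, 8192, 8192),
# required_memory = 8192 * 8192 * 1 + 8192 * 8192 * 1 + 8192 * 8192 * 4 bytes
#                 = 64 MiB + 64 MiB + 256 MiB = 384 MiB.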

X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
    [1, 13824, 5120],
    [4, 12288, 4096],
    [512, 8192, 8192],
    [512, 8192, 32768],
    [512, 32768, 8192],
    [1024, 8192, 16384],
    [1024, 8192, 28672],
    [3072, 3072, 4096],
    [4096, 8192, 16384],
    [8192, 1024, 16384],
    [8192, 4096, 16384],
    [16384, 1024, 8192],
    [16384, 4096, 8192],
    [16384, 8192, 1024],
    [16384, 8192, 4096],
    [32768, 128, 4096],
    [32768, 4096, 128],
    [4096, 128, 4096],
    [8, 128, 16384],
    [8, 16384, 128],
]

X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=["M", "N", "K"],
        # different possible values for `x_name`
        x_vals=X_VALS,
        line_arg="provider",
        # argument name whose value corresponds to a different line in the plot
        line_vals=["triton"],
        # label name for the lines
        line_names=["Triton"],
        # line styles
        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
        plot_name="sglang-fp8-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    torch.manual_seed(0)

    block_size = [128, 128]
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale

    quantiles = [0.5, 0.0, 1.0]

    if provider == "triton":
        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)
    else:
        raise NotImplementedError(f"Unsupported provider {provider}")

    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
    gbps = lambda ms: (M * K + K * N + 2.0 * M * N) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == "__main__":
    benchmark.run(show_plots=False, print_data=True)
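
A minimal standalone sketch of how the new helpers could be exercised outside the CI harness, assuming an XPU device is available and the file above is importable as block_fp8_gemm_benchmark; check_single_shape is a hypothetical helper, not part of this commit:

import torch

from block_fp8_gemm_benchmark import (
    native_w8a8_block_fp8_matmul,
    w8a8_block_fp8_matmul,
)


def check_single_shape(M=512, N=1024, K=2048, block_size=(128, 128)):
    # Hypothetical helper: build random fp8 operands with per-block scales,
    # run the Triton kernel, and compare against the pure-PyTorch reference.
    torch.manual_seed(0)
    fp8_info = torch.finfo(torch.float8_e4m3fn)

    A = (torch.rand(M, K, device="xpu") - 0.5) * 2 * fp8_info.max
    B = (torch.rand(N, K, device="xpu") - 0.5) * 2 * fp8_info.max
    A_fp8 = A.clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)
    B_fp8 = B.clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)

    block_n, block_k = block_size
    As = torch.rand(M, (K + block_k - 1) // block_k, device="xpu") * 1e-2
    Bs = torch.rand((N + block_n - 1) // block_n, (K + block_k - 1) // block_k, device="xpu") * 1e-2

    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, list(block_size))
    ref = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, list(block_size))
    print("max abs diff:", (out.float() - ref.float()).abs().max().item())


if __name__ == "__main__":
    check_single_shape()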

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 2 additions & 13 deletions
@@ -82,19 +82,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
                                                              B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')

     if provider == 'triton':
-        triton_fn = lambda: decode_attention_fwd(
-            q,
-            k_buffer,
-            v_buffer,
-            o,
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            attn_lse,
-            num_kv_splits,
-            max_kv_splits,
-            sm_scale,
-        )
+        triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
+                                                 num_kv_splits, max_kv_splits, sm_scale)

     # TODO: decode attention should have the validation function
     if VALIDATE:
