Skip to content

Commit 935cef5

Browse files
committed
Integrate sglang prefill/decode/extend kernel to benchmarks
Port prefill attn and decode attn from sglang Add validation temp add extend attention disable debug ir dump Update three stage attention benchmark Add sglang kernel benchmark to action use 1e-3 atol remove sglang benchmark from triton-benchmarks Fix setup bdist_wheel Add sglang to thirdparty test Address review comments Remove sglang from tests Fix CI Address review comments Integrate sglang prefill/decode/extend kernel to benchmarks Port prefill attn and decode attn from sglang Add validation temp add extend attention disable debug ir dump Update three stage attention benchmark Add sglang kernel benchmark to action use 1e-3 atol remove sglang benchmark from triton-benchmarks Fix setup bdist_wheel Add sglang to thirdparty test Address review comments Remove sglang from tests Adjust params term Adjust tflops computation
1 parent e98561c commit 935cef5

File tree

5 files changed

+391
-3
lines changed

5 files changed

+391
-3
lines changed

.github/workflows/third-party-benchmarks.yml

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ jobs:
7272
- name: Setup Triton
7373
uses: ./.github/actions/setup-triton
7474

75-
- name: Install benchmark dependencies
75+
- name: Install benchmarks
7676
id: install
77+
run: |
78+
cd benchmarks
79+
pip install .
80+
81+
- name: Install benchmark dependencies
82+
id: install_deps
7783
run: |
7884
pip install transformers pandas pytest
7985
@@ -83,7 +89,7 @@ jobs:
8389
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
8490
8591
- name: Run Liger-Kernel benchmarks
86-
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
92+
if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
8793
run: |
8894
source ./scripts/capture-hw-details.sh
8995
@@ -102,6 +108,38 @@ jobs:
102108
# Return the captured return code at the end
103109
exit "$RET_CODE"
104110
111+
- name: Install SGLANG
112+
run: |
113+
git clone https://github.com/sgl-project/sglang.git
114+
pip install sglang/python[srt_xpu]
115+
116+
- name: Run SGLANG attention prefill stage benchmark
117+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
118+
run: |
119+
cd benchmarks/third_party/sglang
120+
python prefill_attention_benchmark.py --reports $REPORTS
121+
122+
source ../../../scripts/capture-hw-details.sh
123+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
124+
125+
- name: Run SGLANG attention decode stage benchmark
126+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
127+
run: |
128+
cd benchmarks/third_party/sglang
129+
python decode_attention_benchmark.py --reports $REPORTS
130+
131+
source ../../../scripts/capture-hw-details.sh
132+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
133+
134+
- name: Run SGLANG attention append stage benchmark
135+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
136+
run: |
137+
cd benchmarks/third_party/sglang
138+
python extended_attention_benchmark.py --reports $REPORTS
139+
140+
source ../../../scripts/capture-hw-details.sh
141+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142+
105143
- name: Upload benchmark reports
106144
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
107145
uses: actions/upload-artifact@v4
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device, *, max_kv_splits=8, kv_splits=4):
    """Build the argument tuple for one decode-stage attention benchmark point.

    Every request in the batch gets exactly ``N_CTX`` cached tokens, stored
    back to back in the KV buffers, and a single query token.

    Args:
        B: number of requests in the batch.
        N_CTX: number of cached tokens per request.
        H_Q: number of query heads.
        H_KV: number of key/value heads.
        D: head dimension.
        dtype: dtype of the q/k/v/o tensors.
        device: device every tensor is allocated on.
        max_kv_splits: size of the split dimension of the ``attn_logits`` /
            ``attn_lse`` scratch buffers (default 8, the previously hard-coded value).
        kv_splits: value every entry of ``num_kv_splits`` is filled with
            (default 4, the previously hard-coded value).

    Returns:
        The tuple ``(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits,
        attn_lse, num_kv_splits, max_kv_splits, sm_scale)``.

    Raises:
        ValueError: if ``kv_splits`` is not in ``[1, max_kv_splits]``.
    """
    # The scratch buffers below only have `max_kv_splits` slots along the split
    # dimension. NOTE(review): presumably the kernel indexes them by split id,
    # so a larger per-request split count would run past the end - confirm.
    if not 0 < kv_splits <= max_kv_splits:
        raise ValueError(f'kv_splits must be in [1, max_kv_splits={max_kv_splits}], got {kv_splits}')

    total_tokens = B * N_CTX
    sm_scale = 1.0 / (D**0.5)
    num_kv_splits = torch.full((B, ), kv_splits, dtype=torch.int32, device=device)

    # q represents the new token being generated, one per request.
    # The three random draws stay in q -> k -> v order so a fixed seed keeps
    # producing the same inputs.
    q = torch.randn(B, H_Q, D, dtype=dtype, device=device)

    # k_buffer and v_buffer represent all previous tokens of all requests
    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)

    # o has the same shape as q
    o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)

    # Request i owns rows kv_indices[kv_indptr[i]:kv_indptr[i + 1]] of the buffers;
    # with equal lengths and identity indices that is rows [i * N_CTX, (i + 1) * N_CTX).
    b_seq_len = torch.full((B, ), N_CTX, device=device)
    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
    kv_indptr[1:] = torch.cumsum(b_seq_len, dim=0)
    kv_indices = torch.arange(total_tokens, device=device)

    # Uninitialised float32 scratch space handed to the kernel.
    attn_logits = torch.empty(
        (B, H_Q, max_kv_splits, D),
        dtype=torch.float32,
        device=device,
    )
    attn_lse = torch.empty(
        (B, H_Q, max_kv_splits),
        dtype=torch.float32,
        device=device,
    )

    return (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits,
            sm_scale)
44+
45+
46+
# pylint: disable=unused-argument
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # benchmark arguments that make up one x-axis point of the plot
        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
        # every head configuration is swept over all batch sizes; the head
        # configuration is the outer (slowest varying) loop
        x_vals=[[bs, [1024, 64], n_q_heads, n_kv_heads, head_dim, 'fwd', False]
                for n_q_heads, n_kv_heads, head_dim in [(32, 8, 128), (32, 32, 96), (28, 4, 128)]
                for bs in [1, 16, 32, 64, 128]],
        # argument whose value selects the line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['triton'],
        # legend label of each line
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        # y-axis labels
        ylabel=['GB/s', 'TFlops'],
        # plot name, also used as the file name when the plot is saved
        plot_name='decode-attn-performance',
        args={},
    ))
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
    """Time ``decode_attention_fwd`` for a single benchmark configuration.

    Args:
        B: batch size.
        SEQ_LENS: list of lengths; their sum is used as the per-request KV length.
        H_Q: number of query heads.
        H_KV: number of key/value heads.
        D: head dimension.
        MODE: not used by this benchmark.
        VALIDATE: when true the run is rejected, since no reference exists yet.
        provider: implementation to time; only ``'triton'`` is accepted.

    Returns:
        ``((gbps_mean, gbps_max_ms, gbps_min_ms), (tflops_mean, tflops_max_ms,
        tflops_min_ms), cv)`` where each throughput is derived from the
        matching timing returned by ``do_bench``.

    Raises:
        NotImplementedError: for an unknown provider, or when ``VALIDATE`` is set.
    """
    torch.manual_seed(0)
    kv_len = sum(SEQ_LENS)

    (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits,
     sm_scale) = gen_args(B, kv_len, H_Q, H_KV, D, torch.bfloat16, 'xpu')

    if provider != 'triton':
        raise NotImplementedError(f'Unsupported provider {provider}')

    def triton_fn():
        return decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

    # TODO: decode attention should have the validation function
    if VALIDATE:
        raise NotImplementedError('Validation is not implemented for decode stage')

    # The first value returned by do_bench is not needed here.
    _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                          quantiles=[0.5, 0.0, 1.0])

    # Work and traffic estimates are kept as exact integers; only the final
    # scaling by the measured time is done in floating point.
    flop_count = 2 * B * (H_Q + H_KV) * kv_len * D
    byte_count = 2 * B * (H_Q + H_KV * kv_len) * D * 2

    def tflops(ms):
        return flop_count * (1e-12) / (ms * 1e-3)

    def gbps(ms):
        return byte_count * (1e-9) / (ms * 1e-3)

    bandwidth = (gbps(mean), gbps(max_ms), gbps(min_ms))
    compute = (tflops(mean), tflops(max_ms), tflops(min_ms))
    return bandwidth, compute, cv
112+
113+
114+
# Script entry point: invoke the decorated decode benchmark's run() with
# plots disabled (show_plots=False) and data printing enabled (print_data=True).
if __name__ == '__main__':
115+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
from sglang.srt.layers.attention.triton_ops.extend_attention import (
3+
extend_attention_fwd,
4+
redundant_attention,
5+
)
6+
import triton_kernels_benchmark as benchmark_suit
7+
8+
9+
def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
    """Build the tensors for one extend-stage attention benchmark point.

    Each request gets a random prefix length and a random extend length, both
    drawn from ``[1, N_CTX // 2)``. The KV buffers hold prefix + extend tokens
    of every request back to back; the ``*_extend`` tensors hold only the
    extend tokens, again back to back.

    Args:
        B: number of requests in the batch.
        N_CTX: upper bound used for sampling the lengths; must be at least 4.
        H_Q: number of query heads.
        H_KV: number of key/value heads.
        D: head dimension.
        dtype: dtype of the q/k/v/o tensors.
        device: device every tensor is allocated on.

    Returns:
        A list of four tuples:
        ``(q_extend, k_extend, v_extend, o_extend, o_redundant)``,
        ``(k_buffer, v_buffer)``,
        ``(qo_indptr, kv_indptr, kv_indices, max_len_extend)`` and
        ``(b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch)``.

    Raises:
        ValueError: if ``N_CTX < 4``, which would leave no valid length to sample.
    """
    if N_CTX < 4:
        raise ValueError(f'N_CTX must be at least 4 to sample prefix/extend lengths, got {N_CTX}')

    # Prefix lengths are drawn before extend lengths so a fixed seed keeps
    # producing the same batch layout.
    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
    b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
    b_seq_len = b_seq_len_prefix + b_seq_len_extend
    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

    b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
    # Start row of every request inside the full KV buffers ...
    b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    # ... and inside the packed extend-only tensors.
    b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
    kv_indptr[1:B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)

    # Copy the per-request bookkeeping to the host once, so the loops below
    # work on plain ints instead of reading device scalars every iteration.
    prefix_lens = b_seq_len_prefix.tolist()
    extend_lens = b_seq_len_extend.tolist()
    buffer_starts = b_start_loc.tolist()
    extend_starts = b_start_loc_extend.tolist()
    kv_bounds = kv_indptr.tolist()

    # kv_indices lists, request by request, the buffer rows holding the prefix.
    kv_indices = torch.zeros((sum(prefix_lens), ), dtype=torch.int32, device=device)
    for i, (start, n_prefix) in enumerate(zip(buffer_starts, prefix_lens)):
        kv_indices[kv_bounds[i]:kv_bounds[i + 1]] = torch.arange(start, start + n_prefix, dtype=torch.int32,
                                                                 device=device)

    total_token_num = torch.sum(b_seq_len).item()
    extend_token_num = sum(extend_lens)
    k_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)

    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    for start, n_prefix, n_extend, extend_start in zip(buffer_starts, prefix_lens, extend_lens, extend_starts):
        # The extend tokens of a request sit right after its prefix in the buffers.
        in_buffer = slice(start + n_prefix, start + n_prefix + n_extend)
        packed = slice(extend_start, extend_start + n_extend)
        k_extend[packed] = k_buffer[in_buffer]
        v_extend[packed] = v_buffer[in_buffer]
        # Queries are sampled per request into a temporary tensor, as before,
        # so the random stream is consumed in the same chunks.
        q_extend[packed] = torch.empty((n_extend, H_Q, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)

    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
    qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
    qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

    params = []
    params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
    params.append((k_buffer, v_buffer))
    params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
    params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch))
    return params
61+
62+
63+
# pylint: disable=unused-argument
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # benchmark arguments that make up one x-axis point of the plot
        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
        # every head configuration is swept over all batch sizes; the head
        # configuration is the outer (slowest varying) loop
        x_vals=[[bs, [1024, 128, 512], n_q_heads, n_kv_heads, head_dim, 'fwd', True]
                for n_q_heads, n_kv_heads, head_dim in [(32, 8, 128), (32, 32, 96), (28, 4, 128)]
                for bs in [1, 16, 32, 64, 128]],
        # argument whose value selects the line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['triton'],
        # legend label of each line
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        # y-axis labels
        ylabel=['GB/s', 'TFlops'],
        # plot name, also used as the file name when the plot is saved
        plot_name='extended-attn-performance',
        args={},
    ))
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
    """Time ``extend_attention_fwd`` for a single benchmark configuration.

    Args:
        B: batch size.
        SEQ_LENS: list of lengths; their sum is passed to ``gen_args`` as ``N_CTX``.
        H_Q: number of query heads.
        H_KV: number of key/value heads.
        D: head dimension.
        MODE: not used by this benchmark.
        VALIDATE: when true, the kernel output is first compared against
            ``redundant_attention`` via ``benchmark_suit.assert_close``.
        provider: implementation to time; only ``'triton'`` is accepted.

    Returns:
        ``((gbps_mean, gbps_max_ms, gbps_min_ms), (tflops_mean, tflops_max_ms,
        tflops_min_ms), cv)`` where each throughput is derived from the
        matching timing returned by ``do_bench``.

    Raises:
        NotImplementedError: for an unknown provider.
    """
    torch.manual_seed(0)

    extend_tensors, kv_cache, indexing, reference_meta = gen_args(B, sum(SEQ_LENS), H_Q, H_KV, D, torch.bfloat16,
                                                                  'xpu')
    q_extend, k_extend, v_extend, o_extend, o_redundant = extend_tensors
    k_buffer, v_buffer = kv_cache
    qo_indptr, kv_indptr, kv_indices, max_len_extend = indexing
    b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = reference_meta
    # No custom attention mask is supplied to the kernel.
    custom_mask = None
    mask_indptr = None

    if provider != 'triton':
        raise NotImplementedError(f'Unsupported provider {provider}')

    def triton_fn():
        extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr,
                             kv_indices, custom_mask, mask_indptr, max_len_extend)
        return o_extend

    if VALIDATE:

        def refer_fn():
            redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len,
                                b_seq_len_prefix, max_len_in_batch)
            return o_redundant

        benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer')

    # The first value returned by do_bench is not needed here.
    _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                          quantiles=[0.5, 0.0, 1.0])

    # Token counts are taken from the generated tensors because the
    # per-request lengths are random.
    total_tokens = k_buffer.shape[0]
    extend_tokens = k_extend.shape[0]
    # Work and traffic estimates are kept as exact integers; only the final
    # scaling by the measured time is done in floating point.
    flop_count = (H_Q + H_KV) * (extend_tokens + total_tokens) * total_tokens * D
    byte_count = 2 * (extend_tokens * (H_Q + H_KV) + total_tokens * H_KV) * D * 2

    def tflops(ms):
        return flop_count * (1e-12) / (ms * 1e-3)

    def gbps(ms):
        return byte_count * (1e-9) / (ms * 1e-3)

    bandwidth = (gbps(mean), gbps(max_ms), gbps(min_ms))
    compute = (tflops(mean), tflops(max_ms), tflops(min_ms))
    return bandwidth, compute, cv
134+
135+
136+
# Script entry point: invoke the decorated extend benchmark's run() with
# plots disabled (show_plots=False) and data printing enabled (print_data=True).
if __name__ == '__main__':
137+
benchmark.run(show_plots=False, print_data=True)

0 commit comments

Comments
 (0)