Skip to content

Commit 1b3ec4c

Browse files
committed
Integrate sglang prefill/decode/extend kernel to benchmarks
Port prefill attn and decode attn from sglang Add validation temp add extend attention disable debug ir dump Update three stage attention benchmark Add sglang kernel benchmark to action use 1e-3 atol remove sglang benchmark from triton-benchmarks Fix setup bdist_wheel Add sglang to thirdparty test
1 parent 55a2172 commit 1b3ec4c

File tree

7 files changed

+625
-2
lines changed

7 files changed

+625
-2
lines changed

.github/pins/sglang.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.4.5

.github/workflows/third-party-benchmarks.yml

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ jobs:
7272
- name: Setup Triton
7373
uses: ./.github/actions/setup-triton
7474

75-
- name: Install benchmark dependencies
75+
- name: Install benchmarks
7676
id: install
77+
run: |
78+
cd benchmarks
79+
pip install .
80+
81+
- name: Install benchmark dependencies
82+
id: install_deps
7783
run: |
7884
pip install transformers pandas pytest
7985
@@ -83,7 +89,7 @@ jobs:
8389
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
8490
8591
- name: Run Liger-Kernel benchmarks
86-
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
92+
if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
8793
run: |
8894
source ./scripts/capture-hw-details.sh
8995
@@ -102,6 +108,38 @@ jobs:
102108
# Return the captured return code at the end
103109
exit "$RET_CODE"
104110
111+
- name: Install SGLANG
112+
run: |
113+
SGLANG_PIN="$(<.github/pins/sglang.txt)"
114+
pip install sglang==$SGLANG_PIN
115+
116+
- name: Run SGLANG attention prefill stage benchmark
117+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefill_attention_benchmark.py') }}
118+
run: |
119+
cd benchmarks/third_party/sglang
120+
python prefill_attention_benchmark --reports $REPORTS
121+
122+
source ../../scripts/capture-hw-details.sh
123+
python ../../scripts/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
124+
125+
- name: Run SGLANG attention decode stage benchmark
126+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
127+
run: |
128+
cd benchmarks/third_party/sglang
129+
python decode_attention_benchmark --reports $REPORTS
130+
131+
source ../../scripts/capture-hw-details.sh
132+
python ../../scripts/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
133+
134+
- name: Run SGLANG attention append stage benchmark
135+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
136+
run: |
137+
cd benchmarks/third_party/sglang
138+
python extended_attention_benchmark --reports $REPORTS
139+
140+
source ../../scripts/capture-hw-details.sh
141+
python ../../scripts/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142+
105143
- name: Upload benchmark reports
106144
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
107145
uses: actions/upload-artifact@v4

.github/workflows/third-party-tests.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ jobs:
9696
9797
pytest Liger-Kernel/test/
9898
99+
- name: Run SGLANG tests
100+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
101+
run: |
102+
pip install transformers pandas pytest openai
103+
104+
SGLANG_PIN="$(<.github/pins/sglang.txt)"
105+
pip install datasets decord sglang==$SGLANG_PIN
106+
git clone https://github.com/sgl-project/sglang.git --branch $SGLANG_PIN --single-branch
107+
108+
cd sglang
109+
git apply ../benchmarks/third_party/sglang/sglang.patch
110+
pytest sglang/test/srt/test_triton_attention_kernels.py
111+
99112
- name: Upload test report
100113
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
101114
uses: actions/upload-artifact@v4
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
9+
10+
total_tokens = BATCH * N_CTX
11+
sm_scale = 1.0 / (HEAD_DIM**0.5)
12+
max_kv_splits = 8
13+
num_kv_splits = torch.full((BATCH, ), 4, dtype=torch.int32, device=device)
14+
15+
# q represents the new token being generated, one per batch
16+
q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
17+
18+
# k_buffer and v_buffer represent all previous tokens
19+
k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
20+
v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
21+
22+
# o will have the same shape as q
23+
o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
24+
25+
b_seq_len = torch.full((BATCH, ), N_CTX, device=device)
26+
27+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
28+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
29+
kv_indices = torch.arange(total_tokens, device=device)
30+
31+
attn_logits = torch.empty(
32+
(BATCH, Q_HEAD_NUM, max_kv_splits, HEAD_DIM),
33+
dtype=torch.float32,
34+
device=device,
35+
)
36+
attn_lse = torch.empty(
37+
(BATCH, Q_HEAD_NUM, max_kv_splits),
38+
dtype=torch.float32,
39+
device=device,
40+
)
41+
42+
return (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits,
43+
sm_scale)
44+
45+
46+
# pylint: disable=unused-argument
47+
@benchmark_suit.perf_report(
48+
benchmark_suit.Benchmark(
49+
# argument names to use as an x-axis for the plot
50+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
51+
x_vals=[ #
52+
[bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
53+
] + [ #
54+
[bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128]
55+
] + [ #
56+
[bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
57+
],
58+
line_arg='provider',
59+
# argument name whose value corresponds to a different line in the plot
60+
# possible values for `line_arg``
61+
line_vals=[
62+
'triton',
63+
],
64+
# label name for the lines
65+
line_names=[
66+
'Triton',
67+
],
68+
# line styles
69+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
70+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
71+
plot_name='decode-attn-performance',
72+
# name for the plot. Used also as a file name for saving the plot.
73+
args={},
74+
))
75+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
76+
torch.manual_seed(0)
77+
dtype = torch.bfloat16
78+
quantiles = [0.5, 0.0, 1.0]
79+
N_CTX = sum(SEQ_LENS)
80+
81+
q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
82+
B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
83+
84+
if provider == 'triton':
85+
triton_fn = lambda: decode_attention_fwd(
86+
q,
87+
k_buffer,
88+
v_buffer,
89+
o,
90+
kv_indptr,
91+
kv_indices,
92+
attn_logits,
93+
attn_lse,
94+
num_kv_splits,
95+
max_kv_splits,
96+
sm_scale,
97+
)
98+
99+
# TODO: decode attention should have the validation function
100+
if VALIDATE:
101+
raise NotImplementedError('Validation is not implemented for decode stage')
102+
103+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
104+
105+
else:
106+
raise NotImplementedError(f'Unsupported provider {provider}')
107+
108+
tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
109+
gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
110+
111+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
112+
113+
114+
if __name__ == '__main__':
115+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
from sglang.srt.layers.attention.triton_ops.extend_attention import (
3+
extend_attention_fwd,
4+
redundant_attention,
5+
)
6+
import triton_kernels_benchmark as benchmark_suit
7+
8+
9+
def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
10+
11+
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
12+
b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
13+
b_seq_len = b_seq_len_prefix + b_seq_len_extend
14+
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
15+
16+
b_req_idx = torch.arange(BATCH, dtype=torch.int32, device=device)
17+
b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
18+
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
19+
b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
20+
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
21+
22+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
23+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
24+
kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device)
25+
26+
for i in range(BATCH):
27+
kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
28+
29+
total_token_num = torch.sum(b_seq_len).item()
30+
extend_token_num = torch.sum(b_seq_len_extend).item()
31+
k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
32+
device=device).normal_(mean=0.1, std=0.2)
33+
v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
34+
device=device).normal_(mean=0.1, std=0.2)
35+
36+
k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
37+
v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
38+
q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
39+
for i in range(BATCH):
40+
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
41+
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
42+
extend_start = b_start_loc_extend[i]
43+
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
44+
k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
45+
v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
46+
q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
47+
device=device).normal_(mean=0.1, std=0.2)
48+
49+
o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
50+
o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
51+
52+
b_seq_len_extend = b_seq_len - b_seq_len_prefix
53+
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
54+
qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
55+
qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
56+
57+
params = []
58+
params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
59+
params.append((k_buffer, v_buffer))
60+
params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
61+
params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch))
62+
return params
63+
64+
65+
# pylint: disable=unused-argument
66+
@benchmark_suit.perf_report(
67+
benchmark_suit.Benchmark(
68+
# argument names to use as an x-axis for the plot
69+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
70+
x_vals=[ #
71+
[bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
72+
] + [ #
73+
[bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128]
74+
] + [ #
75+
[bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
76+
],
77+
line_arg='provider',
78+
# argument name whose value corresponds to a different line in the plot
79+
# possible values for `line_arg``
80+
line_vals=[
81+
'triton',
82+
],
83+
# label name for the lines
84+
line_names=[
85+
'Triton',
86+
],
87+
# line styles
88+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
89+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
90+
plot_name='extended-attn-performance',
91+
# name for the plot. Used also as a file name for saving the plot.
92+
args={},
93+
))
94+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
95+
torch.manual_seed(0)
96+
97+
dtype = torch.bfloat16
98+
N_CTX = sum(SEQ_LENS)
99+
100+
params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
101+
q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]
102+
k_buffer, v_buffer = params[1]
103+
qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
104+
b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]
105+
custom_mask = None
106+
mask_indptr = None
107+
108+
quantiles = [0.5, 0.0, 1.0]
109+
if provider == 'triton':
110+
111+
def triton_fn():
112+
extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr,
113+
kv_indices, custom_mask, mask_indptr, max_len_extend)
114+
return o_extend
115+
116+
if VALIDATE:
117+
118+
def refer_fn():
119+
redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len,
120+
b_seq_len_prefix, max_len_in_batch)
121+
return o_redundant
122+
123+
benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer')
124+
125+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
126+
127+
else:
128+
raise NotImplementedError(f'Unsupported provider {provider}')
129+
130+
tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
131+
gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
132+
133+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
134+
135+
136+
if __name__ == '__main__':
137+
benchmark.run(show_plots=False, print_data=True)

0 commit comments

Comments
 (0)