
Commit 5a2076d

Address review comments

Squashed commits:
- Remove sglang from tests
- Fix CI
- Address review comments
- Integrate sglang prefill/decode/extend kernel to benchmarks
- Port prefill attn and decode attn from sglang
- Add validation
- temp add extend attention
- disable debug ir dump
- Update three stage attention benchmark
- Add sglang kernel benchmark to action
- use 1e-3 atol
- remove sglang benchmark from triton-benchmarks
- Fix setup bdist_wheel
- Add sglang to thirdparty test
- Address review comments
- Remove sglang from tests
- Adjust params term
- Adjust tflops computation
1 parent 2092c5d commit 5a2076d

File tree

7 files changed: +55, -290 lines


.github/pins/sglang.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

.github/workflows/third-party-benchmarks.yml

Lines changed: 14 additions & 14 deletions
@@ -110,35 +110,35 @@ jobs:

      - name: Install SGLANG
        run: |
-          SGLANG_PIN="$(<.github/pins/sglang.txt)"
-          pip install sglang==$SGLANG_PIN
+          git clone https://github.com/sgl-project/sglang.git
+          pip install sglang/python[srt_xpu]

      - name: Run SGLANG attention prefill stage benchmark
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefill_attention_benchmark.py') }}
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        run: |
          cd benchmarks/third_party/sglang
-          python prefill_attention_benchmark --reports $REPORTS
+          python prefill_attention_benchmark.py --reports $REPORTS

-          source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

      - name: Run SGLANG attention decode stage benchmark
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        run: |
          cd benchmarks/third_party/sglang
-          python decode_attention_benchmark --reports $REPORTS
+          python decode_attention_benchmark.py --reports $REPORTS

-          source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

      - name: Run SGLANG attention append stage benchmark
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        run: |
          cd benchmarks/third_party/sglang
-          python extended_attention_benchmark --reports $REPORTS
+          python extended_attention_benchmark.py --reports $REPORTS

-          source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

      - name: Upload benchmark reports
        if: ${{ steps.install.outcome == 'success' && !cancelled() }}

.github/workflows/third-party-tests.yml

Lines changed: 0 additions & 13 deletions
@@ -96,19 +96,6 @@ jobs:

          pytest Liger-Kernel/test/

-      - name: Run SGLANG tests
-        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
-        run: |
-          pip install transformers pandas pytest openai
-
-          SGLANG_PIN="$(<.github/pins/sglang.txt)"
-          pip install datasets decord sglang==$SGLANG_PIN
-          git clone https://github.com/sgl-project/sglang.git --branch $SGLANG_PIN --single-branch
-
-          cd sglang
-          git apply ../benchmarks/third_party/sglang/sglang.patch
-          pytest sglang/test/srt/test_triton_attention_kernels.py
-
      - name: Upload test report
        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        uses: actions/upload-artifact@v4

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 15 additions & 15 deletions
@@ -5,36 +5,36 @@
 import triton_kernels_benchmark as benchmark_suit


-def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
+def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):

-    total_tokens = BATCH * N_CTX
-    sm_scale = 1.0 / (HEAD_DIM**0.5)
+    total_tokens = B * N_CTX
+    sm_scale = 1.0 / (D**0.5)
     max_kv_splits = 8
-    num_kv_splits = torch.full((BATCH, ), 4, dtype=torch.int32, device=device)
+    num_kv_splits = torch.full((B, ), 4, dtype=torch.int32, device=device)

-    # q represents the new token being generated, one per batch
-    q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+    # q represents the new token being generated, one per B
+    q = torch.randn(B, H_Q, D, dtype=dtype, device=device)

     # k_buffer and v_buffer represent all previous tokens
-    k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
-    v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)

     # o will have the same shape as q
-    o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+    o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)

-    b_seq_len = torch.full((BATCH, ), N_CTX, device=device)
+    b_seq_len = torch.full((B, ), N_CTX, device=device)

-    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
-    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
+    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
     kv_indices = torch.arange(total_tokens, device=device)

     attn_logits = torch.empty(
-        (BATCH, Q_HEAD_NUM, max_kv_splits, HEAD_DIM),
+        (B, H_Q, max_kv_splits, D),
         dtype=torch.float32,
         device=device,
     )
     attn_lse = torch.empty(
-        (BATCH, Q_HEAD_NUM, max_kv_splits),
+        (B, H_Q, max_kv_splits),
         dtype=torch.float32,
         device=device,
     )
@@ -105,7 +105,7 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
+    tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3)
     gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
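
Note: as a sanity check on the "Adjust tflops computation" change above, here is a minimal standalone sketch of the two decode estimates as plain functions. This is a hedged reading of the diff; the function names, comments, and example sizes are illustrative and not part of the commit.

def decode_tflops(ms, B, H_Q, H_KV, N_CTX, D):
    # Adjusted FLOP estimate: the old lambda multiplied the H_KV head count
    # by an extra factor of N_CTX; the fix applies a single N_CTX factor to
    # both head counts.
    return 2 * B * (H_Q + H_KV) * N_CTX * D * 1e-12 / (ms * 1e-3)

def decode_gbps(ms, B, H_Q, H_KV, N_CTX, D):
    # Unchanged traffic estimate. One plausible reading: q/o rows (H_Q * D)
    # plus the K and V caches (H_KV * N_CTX * D), at 2 bytes per fp16 element.
    return 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * 1e-9 / (ms * 1e-3)

# Example (hypothetical sizes): B=16, H_Q=32, H_KV=8, N_CTX=1024, D=128, 0.5 ms
print(decode_tflops(0.5, 16, 32, 8, 1024, 128))  # ~0.34 TFLOPS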

benchmarks/third_party/sglang/extended_attention_benchmark.py

Lines changed: 25 additions & 25 deletions
@@ -6,53 +6,51 @@
 import triton_kernels_benchmark as benchmark_suit


-def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
+def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):

-    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
-    b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
+    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
+    b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device)
     b_seq_len = b_seq_len_prefix + b_seq_len_extend
     max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

-    b_req_idx = torch.arange(BATCH, dtype=torch.int32, device=device)
-    b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
+    b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device)
     b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-    b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device)
     b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

-    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
-    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
+    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
     kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device)

-    for i in range(BATCH):
+    for i in range(B):
         kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])

     total_token_num = torch.sum(b_seq_len).item()
     extend_token_num = torch.sum(b_seq_len_extend).item()
-    k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device=device).normal_(mean=0.1, std=0.2)
-    v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device=device).normal_(mean=0.1, std=0.2)
-
-    k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
-    v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
-    q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
-    for i in range(BATCH):
+    k_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)
+    v_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)
+
+    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
+    for i in range(B):
         extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
         extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
         extend_start = b_start_loc_extend[i]
         extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
         k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
         v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
-        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
+        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], H_Q, D), dtype=dtype,
                                                         device=device).normal_(mean=0.1, std=0.2)

-    o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
-    o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
+    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

     b_seq_len_extend = b_seq_len - b_seq_len_prefix
     max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-    qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
-    qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
+    qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

     params = []
     params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
@@ -127,8 +125,10 @@ def refer_fn():
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
+    N_CTX_TOTAL = k_buffer.shape[0]
+    N_CTX_EXTEND = k_extend.shape[0]
+    tflops = lambda ms: (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
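
Note: the extended-stage estimates now derive token counts from the generated tensors rather than from B * N_CTX. A minimal standalone sketch of the new lambdas as plain functions follows; this is a hedged reading of the diff, and the function names and comments are illustrative, not from the commit.

def extend_tflops(ms, H_Q, H_KV, N_CTX_EXTEND, N_CTX_TOTAL, D):
    # N_CTX_TOTAL mirrors k_buffer.shape[0] (prefix + extend tokens) and
    # N_CTX_EXTEND mirrors k_extend.shape[0] (extend tokens only), so the
    # estimate tracks the randomly drawn sequence lengths.
    return (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * 1e-12 / (ms * 1e-3)

def extend_gbps(ms, H_Q, H_KV, N_CTX_EXTEND, N_CTX_TOTAL, D):
    # One plausible reading: q/k/v/o rows for the extend tokens plus the full
    # K/V buffers, at 2 bytes per fp16 element (the trailing factor of 2).
    return 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * 1e-9 / (ms * 1e-3)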
