Commit a7a69e2

Address review comments
1 parent 71dad71 commit a7a69e2

File tree

4 files changed, +89 -53 lines changed


.github/workflows/third-party-benchmarks.yml

Lines changed: 6 additions & 6 deletions
@@ -120,8 +120,8 @@ jobs:
         run: |
           git clone https://github.com/sgl-project/sglang.git
           cd sglang
-          git apply ../benchmarks/third_party/sglang/sglang.patch
-          pip install ./python[dev_xpu]
+          git apply ../benchmarks/third_party/sglang/sglang-fix.patch
+          pip install "./python[dev_xpu]"

       # Reinstallation since SGLang installation will force overrides current PyTorch and Triton
       - name: Reinstall PyTorch
@@ -139,7 +139,7 @@ jobs:
           python prefill_attention_benchmark.py --reports $REPORTS

           source ../../../scripts/capture-hw-details.sh
-          python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-prefill-attn-performance.csv $REPORTS/sglang-prefill-attn-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

       - name: Run SGLANG attention decode stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -149,7 +149,7 @@ jobs:
           python decode_attention_benchmark.py --reports $REPORTS

           source ../../../scripts/capture-hw-details.sh
-          python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-decode-attn-performance.csv $REPORTS/sglang-decode-attn-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

       - name: Run SGLANG attention append stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -159,10 +159,10 @@ jobs:
           python extended_attention_benchmark.py --reports $REPORTS

           source ../../../scripts/capture-hw-details.sh
-          python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-extended-attn-performance.csv $REPORTS/sglang-append-attn-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

       - name: Run SGLANG Block FP8 GEMM benchmark
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
           cd benchmarks/third_party/sglang
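
Note: the report-file renames above are assumed to track the plot_name changes in the benchmark scripts below, since each Benchmark declares that its plot name "is used also as a file name for saving the plot". A minimal sketch of that naming assumption (REPORTS is an environment variable set by earlier workflow steps; the directory default here is hypothetical):

import os

REPORTS = os.environ.get('REPORTS', 'reports')  # assumed reports directory
for plot_name in ('sglang-prefill-attn-performance',
                  'sglang-decode-attn-performance',
                  'sglang-extended-attn-performance'):
    # build_report.py is expected to consume a CSV named after the plot_name
    print(os.path.join(REPORTS, f'{plot_name}.csv'))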

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 26 additions & 15 deletions
@@ -43,18 +43,29 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
             sm_scale)


+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, mode, dtype]
+          for sizes in [(1024 + 64, 32, 8, 128), (1024 + 64, 32, 32, 96), (1024 + 64, 28, 4, 128)]
+          for bs in [1, 16, 32, 64, 128]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE'],
-        x_vals=[  #
-            [bs, 1024 + 64, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024 + 64, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024 + 64, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
@@ -68,19 +79,19 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='decode-attn-performance',
+        plot_name='sglang-decode-attn-performance',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, DTYPE, provider):
     torch.manual_seed(0)
-    dtype = torch.bfloat16
-    quantiles = [0.5, 0.0, 1.0]
-    N_CTX = SEQ_LENS
+    dtype = get_dtype(DTYPE)

+    N_CTX = SEQ_LENS
     q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
         B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')

+    quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton' and MODE == 'fwd':
         triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
                                                  num_kv_splits, max_kv_splits, sm_scale)
@@ -89,8 +100,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider):
     else:
         raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')

-    tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * (H_Q + 2 * N_CTX * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
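
As a quick sanity check (a standalone sketch, not part of the committed file), the new X_VALS comprehension in this file expands to 15 rows, one per head configuration and batch size, each matching the extended x_names order ['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE']:

# Expanding the decode-benchmark X_VALS comprehension from this commit.
X_VALS = [[bs, *sizes, mode, dtype]
          for sizes in [(1024 + 64, 32, 8, 128), (1024 + 64, 32, 32, 96), (1024 + 64, 28, 4, 128)]
          for bs in [1, 16, 32, 64, 128]
          for mode in ['fwd']
          for dtype in ['bfloat16']]

assert len(X_VALS) == 15  # 3 head configs x 5 batch sizes x 1 mode x 1 dtype
assert X_VALS[0] == [1, 1088, 32, 8, 128, 'fwd', 'bfloat16']  # 1024 + 64 == 1088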

benchmarks/third_party/sglang/extended_attention_benchmark.py

Lines changed: 32 additions & 19 deletions
@@ -5,10 +5,10 @@


 # pylint: disable=unused-argument
-def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
+def gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, device):

     b_seq_len_prefix = torch.full((B, ), PREFIX_LEN, dtype=torch.int32, device=device)
-    b_seq_len_extend = torch.full((B, ), Q_LEN, dtype=torch.int32, device=device)
+    b_seq_len_extend = torch.full((B, ), EXTEND_LEN, dtype=torch.int32, device=device)
     b_seq_len = b_seq_len_prefix + b_seq_len_extend

     b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device)
@@ -55,18 +55,31 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
     return params


+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, mode, dtype]
+          for sizes in [(512, 1024 + 128, 32, 8, 128),  #
+                        (512, 1024 + 128, 32, 32, 96),  #
+                        (512, 1024 + 128, 28, 4, 128)]
+          for bs in [1, 16, 32, 64, 128]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE'],
-        x_vals=[  #
-            [bs, 512, 1024 + 128, 512, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 512, 1024 + 128, 512, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 512, 1024 + 128, 512, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
-        ],
+        x_names=['B', 'EXTEND_LEN', 'PREFIX_LEN', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
@@ -80,16 +93,15 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='extended-attn-performance',
+        plot_name='sglang-extended-attn-performance',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
+def benchmark(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, MODE, DTYPE, provider):
     torch.manual_seed(0)
+    dtype = get_dtype(DTYPE)

-    dtype = torch.bfloat16
-
-    params = gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, 'xpu')
+    params = gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, 'xpu')
     q_extend, k_extend, v_extend, o_extend = params[0]
     k_buffer, v_buffer = params[1]
     qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
@@ -105,10 +117,11 @@ def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
     else:
         raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')

-    N_CTX_TOTAL = k_buffer.shape[0]
-    N_CTX_EXTEND = k_extend.shape[0]
-    tflops = lambda ms: (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)
+    N_CTX_TOTAL = PREFIX_LEN + EXTEND_LEN
+    N_CTX_EXTEND = EXTEND_LEN
+
+    tflops = lambda ms: B * (N_CTX_EXTEND + N_CTX_TOTAL) * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * ((H_Q * N_CTX_EXTEND) + H_KV * (N_CTX_EXTEND + N_CTX_TOTAL) * 2) * D * 2 * (1e-9) / (ms * 1e-3)

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
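
For context, a small standalone sketch of how the new DTYPE column is assumed to flow through this benchmark: each X_VALS row now ends with a dtype string, which the added get_dtype helper maps to a torch dtype before gen_args allocates tensors (the example row below is the first entry the comprehension produces):

import torch

def get_dtype(dtype_str: str):
    # Same helper as added in this commit: map a dtype name to a torch dtype.
    if dtype_str == 'bfloat16':
        return torch.bfloat16
    if dtype_str == 'float16':
        return torch.float16
    if dtype_str == 'float32':
        return torch.float32
    raise ValueError(f'Unsupported dtype: {dtype_str}')

# One X_VALS row: [B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, MODE, DTYPE]
row = [1, 512, 1024 + 128, 32, 8, 128, 'fwd', 'bfloat16']
*_, dtype_str = row
assert get_dtype(dtype_str) is torch.bfloat16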

benchmarks/third_party/sglang/prefill_attention_benchmark.py

Lines changed: 25 additions & 13 deletions
@@ -22,18 +22,30 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
     return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)


+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, causal, mode, dtype]
+          for bs in [1, 16, 32, 64, 128]
+          for sizes in [(1024, 32, 8, 128), (1024, 32, 32, 96), (1024, 28, 4, 128)]
+          for causal in [True, False]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE'],
-        x_vals=[  #
-            [bs, 1024, 32, 8, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024, 32, 32, 96, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024, 28, 4, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
@@ -47,13 +59,13 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='prefill-attn-performance',
+        plot_name='sglang-prefill-attn-performance',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, DTYPE, provider):
     torch.manual_seed(0)
-    dtype = torch.bfloat16
+    dtype = get_dtype(DTYPE)

     q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')

@@ -66,8 +78,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
         raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')

     N_CTX = SEQ_LENS
-    tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * N_CTX * (H_Q + 2 * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)

     return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
7385
