From 228ac8ed0e4ef6561a6ed3a4f9275e04b04909d8 Mon Sep 17 00:00:00 2001
From: "Ling, Liyang"
Date: Tue, 21 Jan 2025 09:04:38 +0000
Subject: [PATCH 1/5] Integrate sglang prefill/decode/extend kernel to
 benchmarks

Port prefill attn and decode attn from sglang
Add validation
temp add extend attention
disable debug ir dump
Update three stage attention benchmark
Add sglang kernel benchmark to action
use 1e-3 atol
remove sglang benchmark from triton-benchmarks
Fix setup bdist_wheel
Add sglang to thirdparty test
Address review comments
Remove sglang from tests
Fix CI
Adjust params term
Adjust tflops computation
---
 .github/workflows/third-party-benchmarks.yml |  42 +++++-
 .../sglang/decode_attention_benchmark.py     | 115 +++++++++++++++
 .../sglang/extended_attention_benchmark.py   | 137 ++++++++++++++++++
 .../sglang/prefill_attention_benchmark.py    |  98 +++++++++++++
 .../triton_kernels_benchmark/build_report.py |   2 +-
 5 files changed, 391 insertions(+), 3 deletions(-)
 create mode 100644 benchmarks/third_party/sglang/decode_attention_benchmark.py
 create mode 100644 benchmarks/third_party/sglang/extended_attention_benchmark.py
 create mode 100644 benchmarks/third_party/sglang/prefill_attention_benchmark.py

diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml
index ccb53e7077..3d9787504d 100644
--- a/.github/workflows/third-party-benchmarks.yml
+++ b/.github/workflows/third-party-benchmarks.yml
@@ -72,8 +72,14 @@ jobs:
       - name: Setup Triton
         uses: ./.github/actions/setup-triton

-      - name: Install benchmark dependencies
+      - name: Install benchmarks
         id: install
+        run: |
+          cd benchmarks
+          pip install .
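+          # provides the triton_kernels_benchmark package that the SGLang
+          # benchmark scripts below import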
+ + - name: Install benchmark dependencies + id: install_deps run: | pip install transformers pandas pytest @@ -83,7 +89,7 @@ jobs: echo "REPORTS=$PWD/reports" >> $GITHUB_ENV - name: Run Liger-Kernel benchmarks - if: ${{ steps.install.outcome == 'success' && !cancelled() }} + if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }} run: | source ./scripts/capture-hw-details.sh @@ -102,6 +108,38 @@ jobs: # Return the captured return code at the end exit "$RET_CODE" + - name: Install SGLANG + run: | + git clone https://github.com/sgl-project/sglang.git + pip install sglang/python[srt_xpu] + + - name: Run SGLANG attention prefill stage benchmark + if: ${{ steps.install.outcome == 'success' && !cancelled() }} + run: | + cd benchmarks/third_party/sglang + python prefill_attention_benchmark.py --reports $REPORTS + + source ../../../scripts/capture-hw-details.sh + python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + + - name: Run SGLANG attention decode stage benchmark + if: ${{ steps.install.outcome == 'success' && !cancelled() }} + run: | + cd benchmarks/third_party/sglang + python decode_attention_benchmark.py --reports $REPORTS + + source ../../../scripts/capture-hw-details.sh + python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + + - name: Run SGLANG attention append stage benchmark + if: ${{ steps.install.outcome == 'success' && !cancelled() }} + run: | + cd benchmarks/third_party/sglang + python extended_attention_benchmark.py --reports $REPORTS + + source ../../../scripts/capture-hw-details.sh + python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + - name: Upload benchmark reports if: ${{ steps.install.outcome == 'success' && !cancelled() }} uses: actions/upload-artifact@v4 diff --git a/benchmarks/third_party/sglang/decode_attention_benchmark.py b/benchmarks/third_party/sglang/decode_attention_benchmark.py new file mode 100644 index 0000000000..16c932db9b --- /dev/null +++ b/benchmarks/third_party/sglang/decode_attention_benchmark.py @@ -0,0 +1,115 @@ +import torch + +from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd + +import triton_kernels_benchmark as benchmark_suit + + +def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): + + total_tokens = B * N_CTX + sm_scale = 1.0 / (D**0.5) + max_kv_splits = 8 + num_kv_splits = torch.full((B, ), 4, dtype=torch.int32, device=device) + + # q represents the new token being generated, one per B + q = torch.randn(B, H_Q, D, dtype=dtype, device=device) + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device) + v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device) + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D, dtype=dtype, device=device) + + b_seq_len = torch.full((B, ), N_CTX, device=device) + + kv_indptr = 
torch.zeros((B + 1, ), dtype=torch.int32, device=device) + kv_indptr[1:B + 1] = torch.cumsum(b_seq_len[:B], dim=0) + kv_indices = torch.arange(total_tokens, device=device) + + attn_logits = torch.empty( + (B, H_Q, max_kv_splits, D), + dtype=torch.float32, + device=device, + ) + attn_lse = torch.empty( + (B, H_Q, max_kv_splits), + dtype=torch.float32, + device=device, + ) + + return (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, + sm_scale) + + +# pylint: disable=unused-argument +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'], + x_vals=[ # + [bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128] + ], + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=[ + 'triton', + ], + # label name for the lines + line_names=[ + 'Triton', + ], + # line styles + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], + ylabel=['GB/s', 'TFlops'], # label name for the y-axis + plot_name='decode-attn-performance', + # name for the plot. Used also as a file name for saving the plot. + args={}, + )) +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider): + torch.manual_seed(0) + dtype = torch.bfloat16 + quantiles = [0.5, 0.0, 1.0] + N_CTX = sum(SEQ_LENS) + + q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args( + B, N_CTX, H_Q, H_KV, D, dtype, 'xpu') + + if provider == 'triton': + triton_fn = lambda: decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + attn_lse, + num_kv_splits, + max_kv_splits, + sm_scale, + ) + + # TODO: decode attention should have the validation function + if VALIDATE: + raise NotImplementedError('Validation is not implemented for decode stage') + + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) + + else: + raise NotImplementedError(f'Unsupported provider {provider}') + + tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3) + gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3) + + return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv + + +if __name__ == '__main__': + benchmark.run(show_plots=False, print_data=True) diff --git a/benchmarks/third_party/sglang/extended_attention_benchmark.py b/benchmarks/third_party/sglang/extended_attention_benchmark.py new file mode 100644 index 0000000000..da6f392fdd --- /dev/null +++ b/benchmarks/third_party/sglang/extended_attention_benchmark.py @@ -0,0 +1,137 @@ +import torch +from sglang.srt.layers.attention.triton_ops.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) +import triton_kernels_benchmark as benchmark_suit + + +def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): + + b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) + b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + 
b_req_idx = torch.arange(B, dtype=torch.int32, device=device) + b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device) + kv_indptr[1:B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device) + + for i in range(B): + kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer] + v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer] + q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], H_Q, D), dtype=dtype, + device=device).normal_(mean=0.1, std=0.2) + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device) + qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + + params = [] + params.append((q_extend, k_extend, v_extend, o_extend, o_redundant)) + params.append((k_buffer, v_buffer)) + params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend)) + params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch)) + return params + + +# pylint: disable=unused-argument +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'], + x_vals=[ # + [bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] + ], + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=[ + 'triton', + ], + # label name for the lines + line_names=[ + 'Triton', + ], + # line styles + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], + ylabel=['GB/s', 'TFlops'], # label name for the y-axis + plot_name='extended-attn-performance', + # name for the plot. Used also as a file name for saving the plot. 
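+        # the extended-attn-performance.csv written under this name is what the
+        # workflow's build_report step consumes, so the two must stay in sync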
+ args={}, + )) +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider): + torch.manual_seed(0) + + dtype = torch.bfloat16 + N_CTX = sum(SEQ_LENS) + + params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu') + q_extend, k_extend, v_extend, o_extend, o_redundant = params[0] + k_buffer, v_buffer = params[1] + qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2] + b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3] + custom_mask = None + mask_indptr = None + + quantiles = [0.5, 0.0, 1.0] + if provider == 'triton': + + def triton_fn(): + extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr, + kv_indices, custom_mask, mask_indptr, max_len_extend) + return o_extend + + if VALIDATE: + + def refer_fn(): + redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len, + b_seq_len_prefix, max_len_in_batch) + return o_redundant + + benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer') + + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) + + else: + raise NotImplementedError(f'Unsupported provider {provider}') + + N_CTX_TOTAL = k_buffer.shape[0] + N_CTX_EXTEND = k_extend.shape[0] + tflops = lambda ms: (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * (1e-12) / (ms * 1e-3) + gbps = lambda ms: 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * (1e-9) / (ms * 1e-3) + + return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv + + +if __name__ == '__main__': + benchmark.run(show_plots=False, print_data=True) diff --git a/benchmarks/third_party/sglang/prefill_attention_benchmark.py b/benchmarks/third_party/sglang/prefill_attention_benchmark.py new file mode 100644 index 0000000000..e59e76cf5d --- /dev/null +++ b/benchmarks/third_party/sglang/prefill_attention_benchmark.py @@ -0,0 +1,98 @@ +import torch + +from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd + +import triton_kernels_benchmark as benchmark_suit + + +def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device): + max_seq_len = max(SEQ_LENS) + N_CTX = sum(SEQ_LENS) + + # Create random input tensors + q = torch.randn((B * N_CTX, H_Q, D), device=device, dtype=dtype) + k = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype) + v = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype) + o = torch.zeros((B * N_CTX, H_Q, D), device=device, dtype=dtype) + + # Create b_start_loc and b_seq_len tensors + b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device) + b_seq_len = torch.tensor(SEQ_LENS, device=device) + + return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len) + + +# pylint: disable=unused-argument +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'VALIDATE'], + x_vals=[ # + [bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + ] + [ # + [bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + ], + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=[ + 
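+            # only the Triton provider is plotted here; the torch reference is
+            # exercised through the VALIDATE path rather than as a second line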
'triton', + ], + # label name for the lines + line_names=[ + 'Triton', + ], + # line styles + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], + ylabel=['GB/s', 'TFlops'], # label name for the y-axis + plot_name='prefill-attn-performance', + # name for the plot. Used also as a file name for saving the plot. + args={}, + )) +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, VALIDATE, provider): + torch.manual_seed(0) + dtype = torch.bfloat16 + N_CTX = sum(SEQ_LENS) + + q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu') + + quantiles = [0.5, 0.0, 1.0] + if provider == 'triton': + + triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL) + + if VALIDATE: + # FIXME: torch sdpa does not support different H_Q and H_KV + cu_seq_lens = [0] * (len(SEQ_LENS) + 1) + for i, seq_len in enumerate(SEQ_LENS): + cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len + + for i in range(len(SEQ_LENS)): + start, end = cu_seq_lens[i], cu_seq_lens[i + 1] + o_torch = torch.nn.functional.scaled_dot_product_attention( + q[start:end].permute(1, 0, 2), + k[start:end].permute(1, 0, 2), + v[start:end].permute(1, 0, 2), + is_causal=CAUSAL, + ).permute(1, 0, 2) + + cos_sim = torch.nn.functional.cosine_similarity(o[start:end].flatten(), o_torch.flatten(), dim=0) + assert cos_sim.item() > 1 - (1e-5) + assert torch.allclose(o[start:end], o_torch, atol=1e-2) + + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles) + + else: + raise NotImplementedError(f'Unsupported provider {provider}') + + tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3) + gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3) + + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv + + +if __name__ == '__main__': + benchmark.run(show_plots=False, print_data=True) diff --git a/benchmarks/triton_kernels_benchmark/build_report.py b/benchmarks/triton_kernels_benchmark/build_report.py index 87b89316ff..a0a5590ea6 100644 --- a/benchmarks/triton_kernels_benchmark/build_report.py +++ b/benchmarks/triton_kernels_benchmark/build_report.py @@ -90,7 +90,7 @@ def build_report(args: PassedArgs, results_df: Optional[pd.DataFrame] = None): df[p] = df[p].astype(int) df_results["params"] = [json.dumps(j) for j in df[[*param_cols, "MASK"]].to_dict("records")] else: - df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(int).to_dict("records")] + df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(str).to_dict("records")] df_results["tflops"] = df[args.tflops_col] if hbm_col is not None: df_results["hbm_gbs"] = df[hbm_col] From d4b7d956742e455bee8794c61068d75342394d19 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 12 Mar 2025 08:30:32 +0000 Subject: [PATCH 2/5] add sglang block fp8 gemm into benchmark fix bugs rtol atol Move fp8 gemm to sglang benchmark --- .github/workflows/third-party-benchmarks.yml | 11 +- .../sglang/block_fp8_gemm_benchmark.py | 335 ++++++++++++++++++ .../sglang/decode_attention_benchmark.py | 15 +- 3 files changed, 347 insertions(+), 14 deletions(-) create mode 100644 benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 3d9787504d..31a12ae2d3 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ 
b/.github/workflows/third-party-benchmarks.yml
@@ -111,7 +111,7 @@ jobs:
       - name: Install SGLANG
         run: |
           git clone https://github.com/sgl-project/sglang.git
-          pip install sglang/python[srt_xpu]
+          pip install sglang/python[dev_xpu]

       - name: Run SGLANG attention prefill stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -140,6 +140,15 @@ jobs:
           source ../../../scripts/capture-hw-details.sh
           python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

+      - name: Run SGLANG Block FP8 GEMM benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}
+        run: |
+          cd benchmarks/third_party/sglang
+          python block_fp8_gemm_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../../scripts/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Upload benchmark reports
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         uses: actions/upload-artifact@v4
diff --git a/benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py b/benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py
new file mode 100644
index 0000000000..a99e4a0d82
--- /dev/null
+++ b/benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py
@@ -0,0 +1,335 @@
+"""
+Block FP8 GEMM benchmark
+============================
+This benchmark comes from the SGLang kernels:
+https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
+"""
+
+from typing import List
+
+import torch
+import triton
+import triton.language as tl
+
+import triton_kernels_benchmark as benchmark_suit
+
+DEVICE_NAME = torch.xpu.get_device_name()
+DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
+
+
+@triton.jit
+def _w8a8_block_fp8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
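+
+    Scale handling (a sketch of the indexing implemented below): for the K-step
+    starting at k_start, `a` is scaled per token by As[offs_am, k_start // group_k]
+    and `b` per block by Bs[offs_bn // group_n, k_start // group_k] before
+    accumulation.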
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """This function performs matrix multiplication with block-wise quantization.
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+    Args:
+        A: The input tensor, e.g., activation.
+        B: The input tensor, e.g., weight.
+        As: The per-token-group quantization scale for `A`.
+        Bs: The per-block quantization scale for `B`.
+        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+        output_dtype: The dtype of the returned tensor.
+    Returns:
+        torch.Tensor: The result of matmul.
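+
+    Example (illustrative shapes for block_size=[128, 128]):
+        A:  (M, K) float8_e4m3fn, contiguous; As: (M, ceil(K / 128)) float32
+        B:  (N, K) float8_e4m3fn, contiguous; Bs: (ceil(N / 128), ceil(K / 128)) float32
+        C = w8a8_block_fp8_matmul(A, B, As, Bs, [128, 128])  # (M, N), float16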
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + kernel = _w8a8_block_fp8_matmul + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +# For test +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def has_enough_memory(x_val): + # x_val: (M, N, K) + M, N, K = x_val + # a: (M, K) float8_e4m3 + # b: (N, K) float8_e4m3 + # c: (M, N) bfloat16 + # pytorch reference: (M, N) float32 + required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2 + enough_memory = required_memory < DEVICE_TOTAL_MEMORY + if not enough_memory: + print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}") + return enough_memory + + +X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [ + [1, 13824, 5120], + [4, 12288, 4096], + [512, 8192, 8192], + [512, 8192, 32768], + [512, 32768, 8192], + [1024, 8192, 16384], + [1024, 
+    [3072, 3072, 4096],
+    [4096, 8192, 16384],
+    [8192, 1024, 16384],
+    [8192, 4096, 16384],
+    [16384, 1024, 8192],
+    [16384, 4096, 8192],
+    [16384, 8192, 1024],
+    [16384, 8192, 4096],
+    [32768, 128, 4096],
+    [32768, 4096, 128],
+    [4096, 128, 4096],
+    [8, 128, 16384],
+    [8, 16384, 128],
+]
+
+X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]
+
+
+# Benchmark Performance
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=["M", "N", "K"],
+        # different possible values for `x_name`
+        x_vals=X_VALS,
+        line_arg="provider",
+        # argument name whose value corresponds to a different line in the plot
+        line_vals=["triton"],
+        # label name for the lines
+        line_names=["Triton"],
+        # line styles
+        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
+        plot_name="sglang-fp8-gemm-performance",
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)
+
+    block_size = [128, 128]
+    factor_for_scale = 1e-2
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
+    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
+    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    block_n, block_k = block_size[0], block_size[1]
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+
+    As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
+    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
+
+    quantiles = [0.5, 0.0, 1.0]
+
+    if provider == "triton":
+        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f"Unsupported provider {provider}")
+
+    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: (M * K + K * N + 2.0 * M * N) * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == "__main__":
+    benchmark.run(show_plots=False, print_data=True)
diff --git a/benchmarks/third_party/sglang/decode_attention_benchmark.py b/benchmarks/third_party/sglang/decode_attention_benchmark.py
index 16c932db9b..b41c925b98 100644
--- a/benchmarks/third_party/sglang/decode_attention_benchmark.py
+++ b/benchmarks/third_party/sglang/decode_attention_benchmark.py
@@ -82,19 +82,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
         B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')

     if provider == 'triton':
-        triton_fn = lambda: decode_attention_fwd(
-            q,
-            k_buffer,
-            v_buffer,
-            o,
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            attn_lse,
-            num_kv_splits,
-            max_kv_splits,
-            sm_scale,
-        )
+        triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
+                                                 num_kv_splits, max_kv_splits, sm_scale)

         # TODO: decode attention should have the validation function
         if VALIDATE:
From 71dad713f12a928ea10aae993740cf2b782739df Mon Sep 17 00:00:00 2001
From: "Ling, Liyang"
Date: Thu, 22 May 2025 08:21:32 +0000
Subject: [PATCH 3/5] Update extended attention interface

Address review comments
Fix CI XPU not found
---
 .github/workflows/third-party-benchmarks.yml | 40 ++++++++----
 .../sglang/decode_attention_benchmark.py     | 21 +++-----
 .../sglang/extended_attention_benchmark.py   | 54 ++++++------------
 .../sglang/prefill_attention_benchmark.py    | 46 +++++-----------
 .../triton_kernels_benchmark/build_report.py |  2 +-
 5 files changed, 69 insertions(+), 94 deletions(-)

diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml
index 31a12ae2d3..1a50a27438 100644
--- a/.github/workflows/third-party-benchmarks.yml
+++ b/.github/workflows/third-party-benchmarks.yml
@@ -71,25 +71,33 @@ jobs:
       - name: Setup Triton
         uses: ./.github/actions/setup-triton
+        with:
+          command: DEBUG=1 python setup.py bdist_wheel

-      - name: Install benchmarks
-        id: install
+      - name: Install Triton
         run: |
-          cd benchmarks
-          pip install .
+          pip install dist/*.whl

       - name: Install benchmark dependencies
-        id: install_deps
+        id: install
         run: |
           pip install transformers pandas pytest
+          cd benchmarks
+          pip install .
+          pip install intel-pti==0.12.2
+          PTI_LIBS_DIR=$(python -c "import sysconfig; print(sysconfig.get_paths()['stdlib']+'/..')")
+          # the output should contain: `libpti.so`, `libpti_metrics.so.0.12.2` and `libpti_view.so.0.12.2`
+          ls $PTI_LIBS_DIR
+          echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV
+
       - name: Create reports dir
         run: |
           mkdir reports
           echo "REPORTS=$PWD/reports" >> $GITHUB_ENV

       - name: Run Liger-Kernel benchmarks
-        if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           source ./scripts/capture-hw-details.sh
@@ -111,11 +119,22 @@ jobs:
       - name: Install SGLANG
         run: |
           git clone https://github.com/sgl-project/sglang.git
-          pip install sglang/python[dev_xpu]
+          cd sglang
+          git apply ../benchmarks/third_party/sglang/sglang.patch
+          pip install ./python[dev_xpu]
+
+      # Reinstallation needed because the SGLang installation forcibly overrides the current PyTorch and Triton
+      - name: Reinstall PyTorch
+        uses: ./.github/actions/setup-pytorch
+
+      - name: Reinstall Triton
+        run: |
+          pip install ./dist/*.whl

       - name: Run SGLANG attention prefill stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
           cd benchmarks/third_party/sglang
           python prefill_attention_benchmark.py --reports $REPORTS
@@ -125,6 +144,7 @@ jobs:
       - name: Run SGLANG attention decode stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
           cd benchmarks/third_party/sglang
           python decode_attention_benchmark.py --reports $REPORTS
@@ -134,20 +154,22 @@ jobs:
       - name: Run SGLANG attention append stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
           cd benchmarks/third_party/sglang
           python extended_attention_benchmark.py --reports $REPORTS

           source ../../../scripts/capture-hw-details.sh
-          python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python
../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Run SGLANG Block FP8 GEMM benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/third_party/sglang python block_fp8_gemm_benchmark.py --reports $REPORTS source ../../../scripts/capture-hw-details.sh - python ../../../scripts/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Upload benchmark reports if: ${{ steps.install.outcome == 'success' && !cancelled() }} diff --git a/benchmarks/third_party/sglang/decode_attention_benchmark.py b/benchmarks/third_party/sglang/decode_attention_benchmark.py index b41c925b98..b9edc2af80 100644 --- a/benchmarks/third_party/sglang/decode_attention_benchmark.py +++ b/benchmarks/third_party/sglang/decode_attention_benchmark.py @@ -47,13 +47,13 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'], + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE'], x_vals=[ # - [bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128] + [bs, 1024 + 64, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128] + [bs, 1024 + 64, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128] + [bs, 1024 + 64, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] ], line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -72,27 +72,22 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): # name for the plot. Used also as a file name for saving the plot. 
args={}, )) -def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider): +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider): torch.manual_seed(0) dtype = torch.bfloat16 quantiles = [0.5, 0.0, 1.0] - N_CTX = sum(SEQ_LENS) + N_CTX = SEQ_LENS q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args( B, N_CTX, H_Q, H_KV, D, dtype, 'xpu') - if provider == 'triton': + if provider == 'triton' and MODE == 'fwd': triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale) - - # TODO: decode attention should have the validation function - if VALIDATE: - raise NotImplementedError('Validation is not implemented for decode stage') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else: - raise NotImplementedError(f'Unsupported provider {provider}') + raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}') tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3) gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3) diff --git a/benchmarks/third_party/sglang/extended_attention_benchmark.py b/benchmarks/third_party/sglang/extended_attention_benchmark.py index da6f392fdd..4624d4649b 100644 --- a/benchmarks/third_party/sglang/extended_attention_benchmark.py +++ b/benchmarks/third_party/sglang/extended_attention_benchmark.py @@ -1,19 +1,16 @@ import torch from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - redundant_attention, -) + extend_attention_fwd, ) import triton_kernels_benchmark as benchmark_suit -def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): +# pylint: disable=unused-argument +def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device): - b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) - b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) + b_seq_len_prefix = torch.full((B, ), PREFIX_LEN, dtype=torch.int32, device=device) + b_seq_len_extend = torch.full((B, ), Q_LEN, dtype=torch.int32, device=device) b_seq_len = b_seq_len_prefix + b_seq_len_extend - max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - b_req_idx = torch.arange(B, dtype=torch.int32, device=device) b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device) b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device) @@ -45,7 +42,6 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): device=device).normal_(mean=0.1, std=0.2) o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) - o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) b_seq_len_extend = b_seq_len - b_seq_len_prefix max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() @@ -53,10 +49,9 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) params = [] - params.append((q_extend, k_extend, v_extend, o_extend, o_redundant)) + params.append((q_extend, k_extend, v_extend, o_extend)) params.append((k_buffer, v_buffer)) params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend)) - params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch)) return params @@ -64,13 +59,13 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 
device): @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'], + x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE'], x_vals=[ # - [bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] + [bs, 512, 1024 + 128, 512, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128] + [bs, 512, 1024 + 128, 512, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] + [bs, 512, 1024 + 128, 512, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] ], line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -89,41 +84,26 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): # name for the plot. Used also as a file name for saving the plot. args={}, )) -def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider): +def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider): torch.manual_seed(0) dtype = torch.bfloat16 - N_CTX = sum(SEQ_LENS) - params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu') - q_extend, k_extend, v_extend, o_extend, o_redundant = params[0] + params = gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, 'xpu') + q_extend, k_extend, v_extend, o_extend = params[0] k_buffer, v_buffer = params[1] qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2] - b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3] custom_mask = None mask_indptr = None quantiles = [0.5, 0.0, 1.0] - if provider == 'triton': - - def triton_fn(): - extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr, - kv_indices, custom_mask, mask_indptr, max_len_extend) - return o_extend - - if VALIDATE: - - def refer_fn(): - redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len, - b_seq_len_prefix, max_len_in_batch) - return o_redundant - - benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer') - + if provider == 'triton' and MODE == 'fwd': + triton_fn = lambda: extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, + kv_indptr, kv_indices, custom_mask, True, mask_indptr, max_len_extend) _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else: - raise NotImplementedError(f'Unsupported provider {provider}') + raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}') N_CTX_TOTAL = k_buffer.shape[0] N_CTX_EXTEND = k_extend.shape[0] diff --git a/benchmarks/third_party/sglang/prefill_attention_benchmark.py b/benchmarks/third_party/sglang/prefill_attention_benchmark.py index e59e76cf5d..205a68c2f0 100644 --- a/benchmarks/third_party/sglang/prefill_attention_benchmark.py +++ b/benchmarks/third_party/sglang/prefill_attention_benchmark.py @@ -6,8 +6,8 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device): - max_seq_len = max(SEQ_LENS) - N_CTX = sum(SEQ_LENS) + max_seq_len = SEQ_LENS + N_CTX = SEQ_LENS # Create random input tensors q = torch.randn((B * N_CTX, H_Q, D), device=device, dtype=dtype) @@ -16,8 +16,8 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device): o = torch.zeros((B * N_CTX, H_Q, D), device=device, dtype=dtype) # Create b_start_loc and 
b_seq_len tensors - b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device) - b_seq_len = torch.tensor(SEQ_LENS, device=device) + b_start_loc = torch.tensor([0, SEQ_LENS], device=device) + b_seq_len = torch.tensor([SEQ_LENS], device=device) return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len) @@ -26,13 +26,13 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device): @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'VALIDATE'], + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE'], x_vals=[ # - [bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + [bs, 1024, 32, 8, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + [bs, 1024, 32, 32, 96, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128] ] + [ # - [bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128] + [bs, 1024, 28, 4, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128] ], line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -51,43 +51,21 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device): # name for the plot. Used also as a file name for saving the plot. args={}, )) -def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, VALIDATE, provider): +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider): torch.manual_seed(0) dtype = torch.bfloat16 - N_CTX = sum(SEQ_LENS) q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu') quantiles = [0.5, 0.0, 1.0] - if provider == 'triton': - + if provider == 'triton' and MODE == 'fwd': triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL) - - if VALIDATE: - # FIXME: torch sdpa does not support different H_Q and H_KV - cu_seq_lens = [0] * (len(SEQ_LENS) + 1) - for i, seq_len in enumerate(SEQ_LENS): - cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len - - for i in range(len(SEQ_LENS)): - start, end = cu_seq_lens[i], cu_seq_lens[i + 1] - o_torch = torch.nn.functional.scaled_dot_product_attention( - q[start:end].permute(1, 0, 2), - k[start:end].permute(1, 0, 2), - v[start:end].permute(1, 0, 2), - is_causal=CAUSAL, - ).permute(1, 0, 2) - - cos_sim = torch.nn.functional.cosine_similarity(o[start:end].flatten(), o_torch.flatten(), dim=0) - assert cos_sim.item() > 1 - (1e-5) - assert torch.allclose(o[start:end], o_torch, atol=1e-2) - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) - else: - raise NotImplementedError(f'Unsupported provider {provider}') + raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}') + N_CTX = SEQ_LENS tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3) gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3) diff --git a/benchmarks/triton_kernels_benchmark/build_report.py b/benchmarks/triton_kernels_benchmark/build_report.py index a0a5590ea6..87b89316ff 100644 --- a/benchmarks/triton_kernels_benchmark/build_report.py +++ b/benchmarks/triton_kernels_benchmark/build_report.py @@ -90,7 +90,7 @@ def build_report(args: PassedArgs, results_df: Optional[pd.DataFrame] 
= None): df[p] = df[p].astype(int) df_results["params"] = [json.dumps(j) for j in df[[*param_cols, "MASK"]].to_dict("records")] else: - df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(str).to_dict("records")] + df_results["params"] = [json.dumps(j) for j in df[param_cols].astype(int).to_dict("records")] df_results["tflops"] = df[args.tflops_col] if hbm_col is not None: df_results["hbm_gbs"] = df[hbm_col] From c61d0ea539ab9d662016b9d82bc473ae53e9a308 Mon Sep 17 00:00:00 2001 From: "Ling, Liyang" Date: Thu, 29 May 2025 02:32:54 +0000 Subject: [PATCH 4/5] Address review comments --- .github/workflows/third-party-benchmarks.yml | 12 ++--- .../sglang/decode_attention_benchmark.py | 41 +++++++++------ .../sglang/extended_attention_benchmark.py | 52 ++++++++++++------- .../sglang/prefill_attention_benchmark.py | 38 +++++++++----- 4 files changed, 90 insertions(+), 53 deletions(-) diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 1a50a27438..f2ac9d1c77 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -120,8 +120,8 @@ jobs: run: | git clone https://github.com/sgl-project/sglang.git cd sglang - git apply ../benchmarks/third_party/sglang/sglang.patch - pip install ./python[dev_xpu] + git apply ../benchmarks/third_party/sglang/sglang-fix.patch + pip install "./python[dev_xpu]" # Reinstallation since SGLang installation will force overrides current PyTorch and Triton - name: Reinstall PyTorch @@ -139,7 +139,7 @@ jobs: python prefill_attention_benchmark.py --reports $REPORTS source ../../../scripts/capture-hw-details.sh - python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-prefill-attn-performance.csv $REPORTS/sglang-prefill-attn-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Run SGLANG attention decode stage benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() }} @@ -149,7 +149,7 @@ jobs: python decode_attention_benchmark.py --reports $REPORTS source ../../../scripts/capture-hw-details.sh - python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-decode-attn-performance.csv $REPORTS/sglang-decode-attn-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Run SGLANG attention append stage benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() }} @@ -159,10 +159,10 @@ jobs: python extended_attention_benchmark.py --reports $REPORTS source ../../../scripts/capture-hw-details.sh - python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols 
"B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-extended-attn-performance.csv $REPORTS/sglang-append-attn-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Run SGLANG Block FP8 GEMM benchmark - if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }} + if: ${{ steps.install.outcome == 'success' && !cancelled() }} run: | export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/third_party/sglang diff --git a/benchmarks/third_party/sglang/decode_attention_benchmark.py b/benchmarks/third_party/sglang/decode_attention_benchmark.py index b9edc2af80..15fe7f951b 100644 --- a/benchmarks/third_party/sglang/decode_attention_benchmark.py +++ b/benchmarks/third_party/sglang/decode_attention_benchmark.py @@ -43,18 +43,29 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): sm_scale) +def get_dtype(dtype_str: str): + if dtype_str == 'bfloat16': + return torch.bfloat16 + if dtype_str == 'float16': + return torch.float16 + if dtype_str == 'float32': + return torch.float32 + raise ValueError(f'Unsupported dtype: {dtype_str}') + + +X_VALS = [[bs, *sizes, mode, dtype] + for sizes in [(1024 + 64, 32, 8, 128), (1024 + 64, 32, 32, 96), (1024 + 64, 28, 4, 128)] + for bs in [1, 16, 32, 64, 128] + for mode in ['fwd'] + for dtype in ['bfloat16']] + + # pylint: disable=unused-argument @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE'], - x_vals=[ # - [bs, 1024 + 64, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] - ] + [ # - [bs, 1024 + 64, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128] - ] + [ # - [bs, 1024 + 64, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] - ], + x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'], + x_vals=X_VALS, line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -68,19 +79,19 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel=['GB/s', 'TFlops'], # label name for the y-axis - plot_name='decode-attn-performance', + plot_name='sglang-decode-attn-performance', # name for the plot. Used also as a file name for saving the plot. 
args={}, )) -def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider): +def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, DTYPE, provider): torch.manual_seed(0) - dtype = torch.bfloat16 - quantiles = [0.5, 0.0, 1.0] - N_CTX = SEQ_LENS + dtype = get_dtype(DTYPE) + N_CTX = SEQ_LENS q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args( B, N_CTX, H_Q, H_KV, D, dtype, 'xpu') + quantiles = [0.5, 0.0, 1.0] if provider == 'triton' and MODE == 'fwd': triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale) @@ -89,8 +100,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider): else: raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}') - tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * (1e-12) / (ms * 1e-3) - gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3) + tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3) + gbps = lambda ms: B * (H_Q + 2 * N_CTX * H_KV) * D * 2 * (1e-9) / (ms * 1e-3) return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv diff --git a/benchmarks/third_party/sglang/extended_attention_benchmark.py b/benchmarks/third_party/sglang/extended_attention_benchmark.py index 4624d4649b..cc37fd51c4 100644 --- a/benchmarks/third_party/sglang/extended_attention_benchmark.py +++ b/benchmarks/third_party/sglang/extended_attention_benchmark.py @@ -5,10 +5,10 @@ # pylint: disable=unused-argument -def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device): +def gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, device): b_seq_len_prefix = torch.full((B, ), PREFIX_LEN, dtype=torch.int32, device=device) - b_seq_len_extend = torch.full((B, ), Q_LEN, dtype=torch.int32, device=device) + b_seq_len_extend = torch.full((B, ), EXTEND_LEN, dtype=torch.int32, device=device) b_seq_len = b_seq_len_prefix + b_seq_len_extend b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device) @@ -55,18 +55,31 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device): return params +def get_dtype(dtype_str: str): + if dtype_str == 'bfloat16': + return torch.bfloat16 + if dtype_str == 'float16': + return torch.float16 + if dtype_str == 'float32': + return torch.float32 + raise ValueError(f'Unsupported dtype: {dtype_str}') + + +X_VALS = [[bs, *sizes, mode, dtype] + for sizes in [(512, 1024 + 128, 32, 8, 128), # + (512, 1024 + 128, 32, 32, 96), # + (512, 1024 + 128, 28, 4, 128)] + for bs in [1, 16, 32, 64, 128] + for mode in ['fwd'] + for dtype in ['bfloat16']] + + # pylint: disable=unused-argument @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE'], - x_vals=[ # - [bs, 512, 1024 + 128, 512, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] - ] + [ # - [bs, 512, 1024 + 128, 512, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128] - ] + [ # - [bs, 512, 1024 + 128, 512, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128] - ], + x_names=['B', 'EXTEND_LEN', 'PREFIX_LEN', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'], + x_vals=X_VALS, line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -80,16 +93,15 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device): # line styles 
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='extended-attn-performance',
+        plot_name='sglang-extended-attn-performance',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
+def benchmark(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, MODE, DTYPE, provider):
     torch.manual_seed(0)
+    dtype = get_dtype(DTYPE)
 
-    dtype = torch.bfloat16
-
-    params = gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, 'xpu')
+    params = gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, 'xpu')
     q_extend, k_extend, v_extend, o_extend = params[0]
     k_buffer, v_buffer = params[1]
     qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
@@ -105,10 +117,12 @@ def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
     else:
         raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
 
-    N_CTX_TOTAL = k_buffer.shape[0]
-    N_CTX_EXTEND = k_extend.shape[0]
-    tflops = lambda ms: (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)
+    N_CTX_TOTAL = PREFIX_LEN + EXTEND_LEN
+    N_CTX_EXTEND = EXTEND_LEN
+
+    tflops = lambda ms: B * (N_CTX_EXTEND + N_CTX_TOTAL) * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * ((H_Q * N_CTX_EXTEND) + H_KV *
+                           (N_CTX_EXTEND + N_CTX_TOTAL) * 2) * D * 2 * (1e-9) / (ms * 1e-3)
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
diff --git a/benchmarks/third_party/sglang/prefill_attention_benchmark.py b/benchmarks/third_party/sglang/prefill_attention_benchmark.py
index 205a68c2f0..348d4c4ece 100644
--- a/benchmarks/third_party/sglang/prefill_attention_benchmark.py
+++ b/benchmarks/third_party/sglang/prefill_attention_benchmark.py
@@ -22,18 +22,30 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
     return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
 
 
+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, causal, mode, dtype]
+          for bs in [1, 16, 32, 64, 128]
+          for sizes in [(1024, 32, 8, 128), (1024, 32, 32, 96), (1024, 28, 4, 128)]
+          for causal in [True, False]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE'],
-        x_vals=[  #
-            [bs, 1024, 32, 8, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024, 32, 32, 96, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  #
-            [bs, 1024, 28, 4, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
@@ -47,13 +59,13 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='prefill-attn-performance',
+        plot_name='sglang-prefill-attn-performance',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, DTYPE, provider):
     torch.manual_seed(0)
-    dtype = torch.bfloat16
+    dtype = get_dtype(DTYPE)
 
     q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')
 
@@ -66,8 +78,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
         raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
 
     N_CTX = SEQ_LENS
-    tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * N_CTX * (H_Q + 2 * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)
 
     return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv

From 6fd631204416010ee521174b7e848a3144cbdafb Mon Sep 17 00:00:00 2001
From: "Ling, Liyang"
Date: Wed, 23 Jul 2025 01:41:04 +0000
Subject: [PATCH 5/5] Move sglang benchmarks to `triton_kernels_benchmark` folder

---
 .github/workflows/third-party-benchmarks.yml | 8 ++++----
 .../block_fp8_gemm_benchmark.py              | 0
 .../decode_attention_benchmark.py            | 0
 .../extended_attention_benchmark.py          | 0
 .../prefill_attention_benchmark.py           | 0
 5 files changed, 4 insertions(+), 4 deletions(-)
 rename benchmarks/{third_party/sglang => triton_kernels_benchmark}/block_fp8_gemm_benchmark.py (100%)
 rename benchmarks/{third_party/sglang => triton_kernels_benchmark}/decode_attention_benchmark.py (100%)
 rename benchmarks/{third_party/sglang => triton_kernels_benchmark}/extended_attention_benchmark.py (100%)
 rename benchmarks/{third_party/sglang => triton_kernels_benchmark}/prefill_attention_benchmark.py (100%)

diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml
index f2ac9d1c77..dda3b8b1f8 100644
--- a/.github/workflows/third-party-benchmarks.yml
+++ b/.github/workflows/third-party-benchmarks.yml
@@ -135,7 +135,7 @@
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
-          cd benchmarks/third_party/sglang
+          cd benchmarks/triton_kernels_benchmark
           python prefill_attention_benchmark.py --reports $REPORTS
 
           source ../../../scripts/capture-hw-details.sh
@@ -145,7 +145,7 @@
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
-          cd benchmarks/third_party/sglang
+          cd benchmarks/triton_kernels_benchmark
           python decode_attention_benchmark.py --reports $REPORTS
 
           source ../../../scripts/capture-hw-details.sh
@@ -155,7 +155,7 @@
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
        run: |
           export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
-          cd benchmarks/third_party/sglang
+          cd benchmarks/triton_kernels_benchmark
           python extended_attention_benchmark.py --reports $REPORTS
 
           source ../../../scripts/capture-hw-details.sh
@@ -165,7 +165,7 @@
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         run: |
           export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
-          cd benchmarks/third_party/sglang
+          cd benchmarks/triton_kernels_benchmark
           python block_fp8_gemm_benchmark.py --reports $REPORTS
 
           source ../../../scripts/capture-hw-details.sh
diff --git a/benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/block_fp8_gemm_benchmark.py
similarity index 100%
rename from benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py
rename to benchmarks/triton_kernels_benchmark/block_fp8_gemm_benchmark.py
diff --git a/benchmarks/third_party/sglang/decode_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/decode_attention_benchmark.py
similarity index 100%
rename from benchmarks/third_party/sglang/decode_attention_benchmark.py
rename to benchmarks/triton_kernels_benchmark/decode_attention_benchmark.py
diff --git a/benchmarks/third_party/sglang/extended_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/extended_attention_benchmark.py
similarity index 100%
rename from benchmarks/third_party/sglang/extended_attention_benchmark.py
rename to benchmarks/triton_kernels_benchmark/extended_attention_benchmark.py
diff --git a/benchmarks/third_party/sglang/prefill_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/prefill_attention_benchmark.py
similarity index 100%
rename from benchmarks/third_party/sglang/prefill_attention_benchmark.py
rename to benchmarks/triton_kernels_benchmark/prefill_attention_benchmark.py
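
A note on the performance models introduced above: the decode benchmark's
adjusted TFLOPS and GB/s lambdas can be sanity-checked outside the harness.
The sketch below is illustrative only; it restates the patch's formulas as
standalone Python functions using the benchmark's parameter names, assuming
2-byte (bfloat16) elements and a hypothetical kernel time of 0.5 ms:

    # Standalone restatement of the decode benchmark's performance model.
    # B: batch size, N_CTX: cached tokens per sequence, H_Q/H_KV: query/KV
    # head counts, D: head dimension, ms: measured kernel time in milliseconds.

    def decode_tflops(B, N_CTX, H_Q, H_KV, D, ms):
        # The 2 * 2 factor counts one multiply-add in each of the QK and PV
        # stages; the H_KV factor follows the patch's grouped-KV accounting.
        return B * N_CTX * H_Q * D * H_KV * 2 * 2 * 1e-12 / (ms * 1e-3)

    def decode_gbps(B, N_CTX, H_Q, H_KV, D, ms):
        # Bytes moved: the new query rows (B * H_Q * D) plus one pass over
        # the K and V caches (2 * B * N_CTX * H_KV * D), 2 bytes per element.
        return B * (H_Q + 2 * N_CTX * H_KV) * D * 2 * 1e-9 / (ms * 1e-3)

    # One shape from the decode benchmark's X_VALS: B=32, N_CTX=1024+64,
    # H_Q=32, H_KV=8, D=128, at an assumed 0.5 ms.
    print(decode_tflops(32, 1088, 32, 8, 128, 0.5))  # ~9.1 TFLOPS
    print(decode_gbps(32, 1088, 32, 8, 128, 0.5))    # ~286 GB/s

The extended and prefill benchmarks follow the same pattern, with N_CTX
replaced by the corresponding extend and total context lengths.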