diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml
index 57eda075a4..e873c34e9f 100644
--- a/.github/workflows/third-party-benchmarks.yml
+++ b/.github/workflows/third-party-benchmarks.yml
@@ -76,12 +76,26 @@ jobs:
 
       - name: Setup Triton
         uses: ./.github/actions/setup-triton
+        with:
+          command: DEBUG=1 python setup.py bdist_wheel
+
+      - name: Install Triton
+        run: |
+          pip install dist/*.whl
 
       - name: Install benchmark dependencies
        id: install
        run: |
          pip install transformers pandas pytest
+          cd benchmarks
+          pip install .
+          pip install intel-pti==0.12.2
+          PTI_LIBS_DIR=$(python -c "import sysconfig; print(sysconfig.get_paths()['stdlib']+'/..')")
+          # the output should contain: `libpti.so`, `libpti_metrics.so.0.12.2` and `libpti_view.so.0.12.2`
+          ls $PTI_LIBS_DIR
+          echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV
+
       - name: Create reports dir
         run: |
           mkdir reports
@@ -107,6 +121,61 @@
           # Return the captured return code at the end
           exit "$RET_CODE"
 
+      - name: Install SGLANG
+        run: |
+          git clone https://github.com/sgl-project/sglang.git
+          cd sglang
+          git apply ../benchmarks/third_party/sglang/sglang-fix.patch
+          pip install "./python[dev_xpu]"
+
+      # Reinstall, since the SGLang installation force-overrides the current PyTorch and Triton
+      - name: Reinstall PyTorch
+        uses: ./.github/actions/setup-pytorch
+
+      - name: Reinstall Triton
+        run: |
+          pip install ./dist/*.whl
+
+      - name: Run SGLANG attention prefill stage benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          python prefill_attention_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-prefill-attn-performance.csv $REPORTS/sglang-prefill-attn-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+      - name: Run SGLANG attention decode stage benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          python decode_attention_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-decode-attn-performance.csv $REPORTS/sglang-decode-attn-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+      - name: Run SGLANG attention append stage benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          python extended_attention_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-extended-attn-performance.csv $REPORTS/sglang-append-attn-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+      - name: Run SGLANG Block FP8 GEMM benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          python block_fp8_gemm_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Run e2e Llama 3.1 flex attention performance benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'llama3-1')) }}
         run: |
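For reference, the `PTI_LIBS_DIR` lookup in the `Install benchmark dependencies` step is just Python path arithmetic: it takes the interpreter's `stdlib` directory and steps one level up, which is where the `intel-pti` wheel drops its shared libraries. A minimal sketch of the same lookup (the expected file names come from the comment in that step; everything else here is illustrative and not part of the patch):

```python
# Sketch: resolve the directory the intel-pti wheel installs its shared libraries into,
# mirroring the one-liner used in the workflow step above.
import os
import sysconfig

pti_libs_dir = os.path.normpath(sysconfig.get_paths()["stdlib"] + "/..")

# The workflow expects these files to exist before exporting LD_LIBRARY_PATH.
expected = ["libpti.so", "libpti_metrics.so.0.12.2", "libpti_view.so.0.12.2"]
present = [name for name in expected if os.path.exists(os.path.join(pti_libs_dir, name))]
print(pti_libs_dir, present)
```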
diff --git a/benchmarks/triton_kernels_benchmark/block_fp8_gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/block_fp8_gemm_benchmark.py
new file mode 100644
index 0000000000..a99e4a0d82
--- /dev/null
+++ b/benchmarks/triton_kernels_benchmark/block_fp8_gemm_benchmark.py
@@ -0,0 +1,335 @@
+"""
+Block FP8 GEMM benchmark
+============================
+This benchmark comes from the SGLang kernels:
+https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
+"""
+
+from typing import List
+
+import torch
+import triton
+import triton.language as tl
+
+import triton_kernels_benchmark as benchmark_suit
+
+DEVICE_NAME = torch.xpu.get_device_name()
+DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
+
+
+@triton.jit
+def _w8a8_block_fp8_matmul(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. 
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + kernel = _w8a8_block_fp8_matmul + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +# For test +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def has_enough_memory(x_val): + # x_val: (M, N, K) + M, N, K = x_val + # a: (M, K) float8_e4m3 + # b: (N, K) float8_e4m3 + # c: (M, N) bfloat16 + # pytorch reference: (M, N) float32 + required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2 + enough_memory = required_memory < DEVICE_TOTAL_MEMORY + if not enough_memory: + print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}") + return enough_memory + + +X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [ + [1, 13824, 5120], + [4, 12288, 4096], + [512, 8192, 8192], + [512, 8192, 32768], + [512, 32768, 8192], + [1024, 8192, 16384], + [1024, 
+    [3072, 3072, 4096],
+    [4096, 8192, 16384],
+    [8192, 1024, 16384],
+    [8192, 4096, 16384],
+    [16384, 1024, 8192],
+    [16384, 4096, 8192],
+    [16384, 8192, 1024],
+    [16384, 8192, 4096],
+    [32768, 128, 4096],
+    [32768, 4096, 128],
+    [4096, 128, 4096],
+    [8, 128, 16384],
+    [8, 16384, 128],
+]
+
+X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]
+
+
+# Benchmark Performance
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=["M", "N", "K"],
+        # different possible values for `x_name`
+        x_vals=X_VALS,
+        line_arg="provider",
+        # argument name whose value corresponds to a different line in the plot
+        line_vals=["triton"],
+        # label name for the lines
+        line_names=["Triton"],
+        # line styles
+        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
+        plot_name="sglang-fp8-gemm-performance",
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)
+
+    block_size = [128, 128]
+    factor_for_scale = 1e-2
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
+    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
+    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    block_n, block_k = block_size[0], block_size[1]
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+
+    As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
+    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
+
+    quantiles = [0.5, 0.0, 1.0]
+
+    if provider == "triton":
+        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f"Unsupported provider {provider}")
+
+    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: (M * K + K * N + 2.0 * M * N) * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == "__main__":
+    benchmark.run(show_plots=False, print_data=True)
diff --git a/benchmarks/triton_kernels_benchmark/decode_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/decode_attention_benchmark.py
new file mode 100644
index 0000000000..15fe7f951b
--- /dev/null
+++ b/benchmarks/triton_kernels_benchmark/decode_attention_benchmark.py
@@ -0,0 +1,110 @@
+import torch
+
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+
+import triton_kernels_benchmark as benchmark_suit
+
+
+def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
+
+    total_tokens = B * N_CTX
+    sm_scale = 1.0 / (D**0.5)
+    max_kv_splits = 8
+    num_kv_splits = torch.full((B, ), 4, dtype=torch.int32, device=device)
+
+    # q represents the new token being generated, one per B
+    q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
+
+    # k_buffer and v_buffer represent all previous tokens
+    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+
+    # o will have the same shape as q
+    o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)
+
+    b_seq_len = torch.full((B, ), N_CTX, device=device)
+
+    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+    kv_indices = torch.arange(total_tokens, device=device)
+
+    attn_logits = torch.empty(
+        (B, H_Q, max_kv_splits, D),
+        dtype=torch.float32,
+        device=device,
+    )
+    attn_lse = torch.empty(
+        (B, H_Q, max_kv_splits),
+        dtype=torch.float32,
+        device=device,
+    )
+
+    return (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits,
+            sm_scale)
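The comments in `gen_args` describe a ragged KV cache: `q` holds one newly generated token per request, while `kv_indptr` and `kv_indices` form a CSR-style index into the flat `k_buffer`/`v_buffer`. A tiny worked example of that indexing, with toy sizes and CPU tensors (purely illustrative, not part of the benchmark):

```python
# Sketch: the CSR-style KV-cache indexing that gen_args() builds, with tiny sizes.
import torch

B, N_CTX = 2, 3
total_tokens = B * N_CTX                              # 6 cache rows in k_buffer / v_buffer

b_seq_len = torch.full((B, ), N_CTX)                  # each request has 3 cached tokens
kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(b_seq_len, dim=0)        # tensor([0, 3, 6], dtype=torch.int32)
kv_indices = torch.arange(total_tokens)               # tensor([0, 1, 2, 3, 4, 5])

# Request i attends over cache rows kv_indices[kv_indptr[i]:kv_indptr[i + 1]]:
# request 0 -> rows [0, 1, 2], request 1 -> rows [3, 4, 5].
```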
+
+
+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, mode, dtype]
+          for sizes in [(1024 + 64, 32, 8, 128), (1024 + 64, 32, 32, 96), (1024 + 64, 28, 4, 128)]
+          for bs in [1, 16, 32, 64, 128]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
+# pylint: disable=unused-argument
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=[
+            'triton',
+        ],
+        # label name for the lines
+        line_names=[
+            'Triton',
+        ],
+        # line styles
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        plot_name='sglang-decode-attn-performance',
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, DTYPE, provider):
+    torch.manual_seed(0)
+    dtype = get_dtype(DTYPE)
+
+    N_CTX = SEQ_LENS
+    q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
+        B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
+
+    quantiles = [0.5, 0.0, 1.0]
+    if provider == 'triton' and MODE == 'fwd':
+        triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
+                                                 num_kv_splits, max_kv_splits, sm_scale)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
+
+    tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * (H_Q + 2 * N_CTX * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
diff --git a/benchmarks/triton_kernels_benchmark/extended_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/extended_attention_benchmark.py
new file mode 100644
index 0000000000..cc37fd51c4
--- /dev/null
+++ b/benchmarks/triton_kernels_benchmark/extended_attention_benchmark.py
@@ -0,0 +1,131 @@
+import torch
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
+    extend_attention_fwd, )
+import triton_kernels_benchmark as benchmark_suit
+
+
+# pylint: disable=unused-argument
+def gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, device):
+
+    b_seq_len_prefix = torch.full((B, ), PREFIX_LEN, dtype=torch.int32, device=device)
+    b_seq_len_extend = torch.full((B, ), EXTEND_LEN, dtype=torch.int32, device=device)
+    b_seq_len = b_seq_len_prefix + b_seq_len_extend
+
+    b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device)
+    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+    b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device)
+    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+    kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+    kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device)
+
+    for i in range(B):
+        kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
+
+    total_token_num = torch.sum(b_seq_len).item()
+    extend_token_num = torch.sum(b_seq_len_extend).item()
+    k_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)
+    v_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2)
+
+    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
+    for i in range(B):
+        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+        extend_start = b_start_loc_extend[i]
+        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+        k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], H_Q, D), dtype=dtype,
+                                                        device=device).normal_(mean=0.1, std=0.2)
+
+    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
+
+    b_seq_len_extend = b_seq_len - b_seq_len_prefix
+    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+    qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device)
+    qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
+
+    params = []
+    params.append((q_extend, k_extend, v_extend, o_extend))
+    params.append((k_buffer, v_buffer))
+    params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
+    return params
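To make the prefix/extend split above concrete, here are the shapes `gen_args` produces for one of the configurations benchmarked below. The numbers are simply the code's arithmetic worked out for `B = 16`, `EXTEND_LEN = 512`, `PREFIX_LEN = 1024 + 128`; they are illustrative, not an additional test:

```python
# Worked example of the prefix/extend split for one benchmarked configuration
# (B=16, EXTEND_LEN=512, PREFIX_LEN=1024 + 128, H_Q=32, H_KV=8, D=128).
B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D = 16, 512, 1024 + 128, 32, 8, 128

total_token_num = B * (PREFIX_LEN + EXTEND_LEN)    # 26624 rows in k_buffer / v_buffer
extend_token_num = B * EXTEND_LEN                  # 8192 rows in q_extend / k_extend / v_extend

# k_buffer, v_buffer  : (total_token_num, H_KV, D)  -> (26624, 8, 128)
# q_extend, o_extend  : (extend_token_num, H_Q, D)  -> (8192, 32, 128)
# kv_indices          : (B * PREFIX_LEN,)           -> (18432,) prefix positions in the cache
# qo_indptr, kv_indptr: (B + 1,) CSR offsets of each request's extend / prefix tokens
assert total_token_num == 26624 and extend_token_num == 8192
```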
+
+
+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, mode, dtype]
+          for sizes in [(512, 1024 + 128, 32, 8, 128),  #
+                        (512, 1024 + 128, 32, 32, 96),  #
+                        (512, 1024 + 128, 28, 4, 128)]
+          for bs in [1, 16, 32, 64, 128]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
+# pylint: disable=unused-argument
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['B', 'EXTEND_LEN', 'PREFIX_LEN', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=[
+            'triton',
+        ],
+        # label name for the lines
+        line_names=[
+            'Triton',
+        ],
+        # line styles
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        plot_name='sglang-extended-attn-performance',
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, MODE, DTYPE, provider):
+    torch.manual_seed(0)
+    dtype = get_dtype(DTYPE)
+
+    params = gen_args(B, EXTEND_LEN, PREFIX_LEN, H_Q, H_KV, D, dtype, 'xpu')
+    q_extend, k_extend, v_extend, o_extend = params[0]
+    k_buffer, v_buffer = params[1]
+    qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
+    custom_mask = None
+    mask_indptr = None
+
+    quantiles = [0.5, 0.0, 1.0]
+    if provider == 'triton' and MODE == 'fwd':
+        triton_fn = lambda: extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr,
+                                                 kv_indptr, kv_indices, custom_mask, True, mask_indptr, max_len_extend)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
+
+    N_CTX_TOTAL = PREFIX_LEN + EXTEND_LEN
+    N_CTX_EXTEND = EXTEND_LEN
+
+    tflops = lambda ms: B * (N_CTX_EXTEND + N_CTX_TOTAL) * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * ((H_Q * N_CTX_EXTEND) + H_KV *
+                           (N_CTX_EXTEND + N_CTX_TOTAL) * 2) * D * 2 * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
diff --git a/benchmarks/triton_kernels_benchmark/prefill_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/prefill_attention_benchmark.py
new file mode 100644
index 0000000000..348d4c4ece
--- /dev/null
+++ b/benchmarks/triton_kernels_benchmark/prefill_attention_benchmark.py
@@ -0,0 +1,88 @@
+import torch
+
+from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd
+
+import triton_kernels_benchmark as benchmark_suit
+
+
+def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
+    max_seq_len = SEQ_LENS
+    N_CTX = SEQ_LENS
+
+    # Create random input tensors
+    q = torch.randn((B * N_CTX, H_Q, D), device=device, dtype=dtype)
+    k = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype)
+    v = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype)
+    o = torch.zeros((B * N_CTX, H_Q, D), device=device, dtype=dtype)
+
+    # Create b_start_loc and b_seq_len tensors
+    b_start_loc = torch.tensor([0, SEQ_LENS], device=device)
+    b_seq_len = torch.tensor([SEQ_LENS], device=device)
+
+    return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
+
+
+def get_dtype(dtype_str: str):
+    if dtype_str == 'bfloat16':
+        return torch.bfloat16
+    if dtype_str == 'float16':
+        return torch.float16
+    if dtype_str == 'float32':
+        return torch.float32
+    raise ValueError(f'Unsupported dtype: {dtype_str}')
+
+
+X_VALS = [[bs, *sizes, causal, mode, dtype]
+          for bs in [1, 16, 32, 64, 128]
+          for sizes in [(1024, 32, 8, 128), (1024, 32, 32, 96), (1024, 28, 4, 128)]
+          for causal in [True, False]
+          for mode in ['fwd']
+          for dtype in ['bfloat16']]
+
+
+# pylint: disable=unused-argument
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'DTYPE'],
+        x_vals=X_VALS,
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=[
+            'triton',
+        ],
+        # label name for the lines
+        line_names=[
+            'Triton',
+        ],
+        # line styles
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        plot_name='sglang-prefill-attn-performance',
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, DTYPE, provider):
+    torch.manual_seed(0)
+    dtype = get_dtype(DTYPE)
+
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')
+
+    quantiles = [0.5, 0.0, 1.0]
+    if provider == 'triton' and MODE == 'fwd':
+        triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider} and mode {MODE}')
+
+    N_CTX = SEQ_LENS
+    tflops = lambda ms: B * N_CTX * H_Q * D * H_KV * 2 * 2 * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: B * N_CTX * (H_Q + 2 * H_KV) * D * 2 * (1e-9) / (ms * 1e-3)
+
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
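For orientation, the reported metrics in these files are plain arithmetic over the problem shape and the measured time. A worked example using the prefill formulas above for one configuration; the 1.0 ms timing is a made-up placeholder, not a measurement:

```python
# Worked example of the prefill metrics for B=1, SEQ_LENS=1024, H_Q=32, H_KV=8, D=128,
# assuming a hypothetical mean kernel time of 1.0 ms.
B, N_CTX, H_Q, H_KV, D = 1, 1024, 32, 8, 128
ms = 1.0

tflops = B * N_CTX * H_Q * D * H_KV * 2 * 2 * 1e-12 / (ms * 1e-3)   # ~0.134 TFLOPS
gbps = B * N_CTX * (H_Q + 2 * H_KV) * D * 2 * 1e-9 / (ms * 1e-3)    # ~12.6 GB/s

print(f"{tflops:.3f} TFLOPS, {gbps:.3f} GB/s")
```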