| 1 | +""" |
| 2 | +Block FP8 Gemm benchmark |
| 3 | +============================ |
| 4 | +This benchmark is come from SGLang kernels. |
| 5 | +https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375 |
| 6 | +""" |

from typing import List

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

DEVICE_NAME = torch.xpu.get_device_name()
DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.
    """

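    # De-linearize the 1D program id into a (pid_m, pid_n) tile index using
    # grouped ordering (GROUP_SIZE_M row-tiles per group) so that programs in the
    # same group reuse B tiles from cache.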
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

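    # Accumulate in float32; each K tile of A and B is dequantized with its
    # per-group scales (a_s, b_s) before being added to the accumulator.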
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

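    # Cast the accumulator to the output element type before storing.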
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    # Default config
    # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
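    # With BLOCK_SIZE_N == block_size[0] and BLOCK_SIZE_K == block_size[1], each
    # K-loop step of a tile lies within a single quantization group, so one A
    # scale per row and one B scale per tile suffice for that step.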
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3,
    }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    kernel = _w8a8_block_fp8_matmul

    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C

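# Minimal usage sketch (illustrative only, not exercised by the benchmark below;
# the shapes here are assumptions chosen to match the 128x128 block layout):
#
#   M, N, K = 16, 256, 256
#   A = torch.randn(M, K, device="xpu").to(torch.float8_e4m3fn)             # activations
#   B = torch.randn(N, K, device="xpu").to(torch.float8_e4m3fn)             # weights
#   As = torch.rand(M, K // 128, dtype=torch.float32, device="xpu")         # per-token-group scales
#   Bs = torch.rand(N // 128, K // 128, dtype=torch.float32, device="xpu")  # per-block scales
#   C = w8a8_block_fp8_matmul(A, B, As, Bs, block_size=[128, 128])          # (M, N) float16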

# Reference implementation in native torch, used for the correctness check.
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """

    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [[
        B[
            j * block_n:min((j + 1) * block_n, N),
            i * block_k:min((i + 1) * block_k, K),
        ] for i in range(k_tiles)
    ] for j in range(n_tiles)]
    C_tiles = [C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def has_enough_memory(x_val):
    # x_val: (M, N, K)
    M, N, K = x_val
    # a: (M, K) float8_e4m3
    # b: (N, K) float8_e4m3
    # c: (M, N) float16
    # pytorch reference: (M, N) float32
    required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
    if not enough_memory:
        print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
    return enough_memory


X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
    [1, 13824, 5120],
    [4, 12288, 4096],
    [512, 8192, 8192],
    [512, 8192, 32768],
    [512, 32768, 8192],
    [1024, 8192, 16384],
    [1024, 8192, 28672],
    [3072, 3072, 4096],
    [4096, 8192, 16384],
    [8192, 1024, 16384],
    [8192, 4096, 16384],
    [16384, 1024, 8192],
    [16384, 4096, 8192],
    [16384, 8192, 1024],
    [16384, 8192, 4096],
    [32768, 128, 4096],
    [32768, 4096, 128],
    [4096, 128, 4096],
    [8, 128, 16384],
    [8, 16384, 128],
]

X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=["M", "N", "K"],
        # different possible values for `x_name`
        x_vals=X_VALS,
        # argument name whose value corresponds to a different line in the plot
        line_arg="provider",
        line_vals=["triton"],
        # label name for the lines
        line_names=["Triton"],
        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
        # name for the plot; also used as a file name for saving the plot
        plot_name="sglang-fp8-gemm-performance",
        args={},
    ))
def benchmark(M, N, K, provider):
    torch.manual_seed(0)

    block_size = [128, 128]
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale

    quantiles = [0.5, 0.0, 1.0]

    if provider == "triton":
        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)

    else:
        raise NotImplementedError(f"Unsupported provider {provider}")

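    # 2*M*N*K flops for the matmul; bytes moved: A and B read as fp8 (1 byte per
    # element) and C written as fp16 (2 bytes per element).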
    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
    gbps = lambda ms: (M * K + K * N + 2.0 * M * N) * 1e-9 / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == "__main__":
    benchmark.run(show_plots=False, print_data=True)