|
6 | 6 | import triton_kernels_benchmark as benchmark_suit |
7 | 7 |
|
8 | 8 |
|
9 | | -def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device): |
| 9 | +def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device): |
10 | 10 |
|
11 | | - b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device) |
12 | | - b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device) |
| 11 | + b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) |
| 12 | + b_seq_len_extend = torch.randint(1, N_CTX // 2, (B, ), dtype=torch.int32, device=device) |
13 | 13 | b_seq_len = b_seq_len_prefix + b_seq_len_extend |
14 | 14 | max_len_in_batch = torch.max(b_seq_len, 0)[0].item() |
15 | 15 |
|
16 | | - b_req_idx = torch.arange(BATCH, dtype=torch.int32, device=device) |
17 | | - b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device=device) |
| 16 | + b_req_idx = torch.arange(B, dtype=torch.int32, device=device) |
| 17 | + b_start_loc = torch.zeros((B, ), dtype=torch.int32, device=device) |
18 | 18 | b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) |
19 | | - b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device=device) |
| 19 | + b_start_loc_extend = torch.zeros((B, ), dtype=torch.int32, device=device) |
20 | 20 | b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) |
21 | 21 |
|
22 | | - kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device) |
23 | | - kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0) |
| 22 | + kv_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device) |
| 23 | + kv_indptr[1:B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) |
24 | 24 | kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device) |
25 | 25 |
|
26 | | - for i in range(BATCH): |
| 26 | + for i in range(B): |
27 | 27 | kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]) |
28 | 28 |
|
29 | 29 | total_token_num = torch.sum(b_seq_len).item() |
30 | 30 | extend_token_num = torch.sum(b_seq_len_extend).item() |
31 | | - k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, |
32 | | - device=device).normal_(mean=0.1, std=0.2) |
33 | | - v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, |
34 | | - device=device).normal_(mean=0.1, std=0.2) |
35 | | - |
36 | | - k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device) |
37 | | - v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device) |
38 | | - q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device) |
39 | | - for i in range(BATCH): |
| 31 | + k_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2) |
| 32 | + v_buffer = torch.empty((total_token_num, H_KV, D), dtype=dtype, device=device).normal_(mean=0.1, std=0.2) |
| 33 | + |
| 34 | + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) |
| 35 | + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device) |
| 36 | + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) |
| 37 | + for i in range(B): |
40 | 38 | extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] |
41 | 39 | extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] |
42 | 40 | extend_start = b_start_loc_extend[i] |
43 | 41 | extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] |
44 | 42 | k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer] |
45 | 43 | v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer] |
46 | | - q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype, |
| 44 | + q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], H_Q, D), dtype=dtype, |
47 | 45 | device=device).normal_(mean=0.1, std=0.2) |
48 | 46 |
|
49 | | - o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device) |
50 | | - o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device) |
| 47 | + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) |
| 48 | + o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device) |
51 | 49 |
|
52 | 50 | b_seq_len_extend = b_seq_len - b_seq_len_prefix |
53 | 51 | max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() |
54 | | - qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device) |
55 | | - qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0) |
| 52 | + qo_indptr = torch.zeros((B + 1, ), dtype=torch.int32, device=device) |
| 53 | + qo_indptr[1:B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) |
56 | 54 |
|
57 | 55 | params = [] |
58 | 56 | params.append((q_extend, k_extend, v_extend, o_extend, o_redundant)) |
@@ -127,8 +125,10 @@ def refer_fn(): |
127 | 125 | else: |
128 | 126 | raise NotImplementedError(f'Unsupported provider {provider}') |
129 | 127 |
|
130 | | - tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3) |
131 | | - gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3) |
| 128 | + N_CTX_TOTAL = k_buffer.shape[0] |
| 129 | + N_CTX_EXTEND = k_extend.shape[0] |
| 130 | + tflops = lambda ms: (H_Q + H_KV) * (N_CTX_EXTEND + N_CTX_TOTAL) * N_CTX_TOTAL * D * (1e-12) / (ms * 1e-3) |
| 131 | + gbps = lambda ms: 2 * (N_CTX_EXTEND * (H_Q + H_KV) + N_CTX_TOTAL * H_KV) * D * 2 * (1e-9) / (ms * 1e-3) |
132 | 132 |
|
133 | 133 | return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv |
134 | 134 |
|
|
0 commit comments