Skip to content

Commit 25f5666

Browse files
authored
Create a GEMM benchmark using tensor descriptors (#3858)
This PR introduces a new GEMM benchmark where tensor descriptors are used rather than block ptrs. --------- Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
1 parent 8a9787f commit 25f5666

File tree

3 files changed

+356
-1
lines changed

3 files changed

+356
-1
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ jobs:
164164
python ../../scripts/build_report.py $REPORTS/matmul-tensor-of-ptr-performance.csv $REPORTS/gemm-tensor-of-ptr-xetla-report.csv --benchmark gemm-tensor-of-ptr --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
165165
python ../../scripts/build_report.py $REPORTS/matmul-tensor-of-ptr-performance.csv $REPORTS/gemm-tensor-of-ptr-onednn-report.csv --benchmark gemm-tensor-of-ptr --compiler onednn --param_cols "B,M,K,N" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG
166166
167+
- name: Run Triton GEMM kernel benchmark - with tensor descriptor
168+
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_tensor_desc_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_tensor_desc_benchmark.py') }}
169+
run: |
170+
cd benchmarks/triton_kernels_benchmark
171+
python gemm_tensor_desc_benchmark.py --reports $REPORTS --n_runs $N_RUNS
172+
source ../../scripts/capture-hw-details.sh
173+
python ../../scripts/build_report.py $REPORTS/matmul-tensor-desc-performance.csv $REPORTS/gemm-tensor-desc-triton-report.csv --benchmark gemm-tensor-desc --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
174+
python ../../scripts/build_report.py $REPORTS/matmul-tensor-desc-performance.csv $REPORTS/gemm-tensor-desc-xetla-report.csv --benchmark gemm-tensor-desc --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
175+
python ../../scripts/build_report.py $REPORTS/matmul-tensor-desc-performance.csv $REPORTS/gemm-tensor-desc-onednn-report.csv --benchmark gemm-tensor-desc --compiler onednn --param_cols "B,M,K,N" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG
176+
167177
- name: Run Triton GEMM (A@B^t) kernel benchmark
168178
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py_abt')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_abt') }}
169179
run: |
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
"""
2+
Gemm benchmark (tensor descriptor)
3+
============================
4+
5+
This benchmark uses tensor descriptors to implement a GEMM kernel.
6+
To compare the performance to XeTLA kernel.
7+
8+
"""
9+
import os
10+
11+
import torch
12+
import triton
13+
import triton.language as tl
14+
15+
import triton_kernels_benchmark as benchmark_suit
16+
from triton_kernels_benchmark import xetla_kernel
17+
18+
TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
19+
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
20+
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
21+
22+
23+
@triton.autotune(
24+
configs=[
25+
triton.Config(
26+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
27+
num_stages=s, num_warps=32) for s in [2]
28+
] + [
29+
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
30+
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
31+
],
32+
key=['M', 'N', 'K'],
33+
)
34+
@triton.jit
35+
def matmul_kernel_with_tensor_descriptors(
36+
# Pointers to matrices
37+
a_ptr, b_ptr, c_ptr,
38+
# Matrix dimensions
39+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
40+
# Stride variables
41+
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
42+
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
43+
stride_cm: tl.constexpr, stride_cn: tl.constexpr,
44+
# Meta-parameters
45+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
46+
pid = tl.program_id(axis=0)
47+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
48+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
49+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
50+
group_id = pid // num_pid_in_group
51+
first_pid_m = group_id * GROUP_SIZE_M
52+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
53+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
54+
pid_n = (pid % num_pid_in_group) // group_size_m
55+
56+
a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
57+
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
58+
b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
59+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
60+
61+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
62+
off_k = 0
63+
for _ in range(0, K, BLOCK_SIZE_K):
64+
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
65+
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
66+
accumulator += tl.dot(a, b)
67+
off_k += BLOCK_SIZE_K
68+
c = accumulator.to(tl.float32)
69+
70+
c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
71+
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
72+
c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
73+
74+
75+
# pylint: disable=unused-argument
76+
@triton.autotune(
77+
configs=[
78+
triton.Config(
79+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
80+
num_stages=s, num_warps=32) for s in [2, 3]
81+
] + [
82+
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
83+
num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)])
84+
] + [
85+
triton.Config(
86+
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
87+
num_stages=s, num_warps=32) for s in [2, 3]
88+
] + [
89+
triton.Config(
90+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
91+
num_stages=s, num_warps=32) for s in [2]
92+
] + [
93+
triton.Config(
94+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
95+
num_stages=s, num_warps=32) for s in [2]
96+
] + [
97+
triton.Config(
98+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
99+
num_stages=s, num_warps=4) for s in [2]
100+
],
101+
key=['M', 'N', 'K'],
102+
)
103+
@triton.jit
104+
def matmul_kernel_with_tensor_descriptors_batched(
105+
# Pointers to matrices
106+
a_ptr, b_ptr, c_ptr,
107+
# Matrix dimensions
108+
B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
109+
# Stride variables
110+
stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, #
111+
stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
112+
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
113+
# Meta-parameters
114+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
115+
bid = tl.program_id(axis=1)
116+
pid = tl.program_id(axis=0)
117+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
118+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
119+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
120+
group_id = pid // num_pid_in_group
121+
first_pid_m = group_id * GROUP_SIZE_M
122+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
123+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
124+
pid_n = (pid % num_pid_in_group) // group_size_m
125+
126+
offset_a = bid.to(tl.int64) * stride_az
127+
offset_b = bid.to(tl.int64) * stride_bz
128+
129+
a_desc = tl.make_tensor_descriptor(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak),
130+
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
131+
b_desc = tl.make_tensor_descriptor(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn),
132+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
133+
134+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
135+
off_k = 0
136+
for _ in range(0, K, BLOCK_SIZE_K):
137+
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
138+
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
139+
accumulator += tl.dot(a, b)
140+
off_k += BLOCK_SIZE_K
141+
c = accumulator.to(tl.float32)
142+
143+
offset_c = bid.to(tl.int64) * stride_cz
144+
c_desc = tl.make_tensor_descriptor(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn),
145+
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
146+
147+
c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
148+
149+
150+
# We can now create a convenience wrapper function that only takes two input tensors,
151+
# and (1) checks any shape constraint; (2) launches the above kernel.
152+
def matmul(a, b, c, transpose_a=False, transpose_b=False):
153+
a_major, a_minor = -2, -1
154+
if transpose_a:
155+
a_major, a_minor = a_minor, a_major
156+
b_minor, b_major = -2, -1
157+
if transpose_b:
158+
b_major, b_minor = b_minor, b_major
159+
160+
assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions'
161+
assert a.is_contiguous(), 'Matrix A must be contiguous'
162+
assert b.is_contiguous(), 'Matrix B must be contiguous'
163+
M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor]
164+
# Check constraints.
165+
if len(a.shape) == 3 and len(b.shape) == 3:
166+
assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
167+
B = a.shape[0]
168+
# 1D launch kernel where each block gets its own program.
169+
grid = lambda META: (
170+
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
171+
B,
172+
)
173+
matmul_kernel_with_tensor_descriptors_batched[grid](
174+
a, b, c, #
175+
B, M, N, K, #
176+
a.stride(0), a.stride(a_major), a.stride(a_minor), #
177+
b.stride(0), b.stride(b_minor), b.stride(b_major), #
178+
c.stride(0), c.stride(1), c.stride(2))
179+
elif len(a.shape) == 2 and len(b.shape) == 2:
180+
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
181+
matmul_kernel_with_tensor_descriptors[grid](
182+
a, b, c, #
183+
M, N, K, #
184+
a.stride(a_major), a.stride(a_minor), #
185+
b.stride(b_minor), b.stride(b_major), #
186+
c.stride(0), c.stride(1))
187+
else:
188+
assert False, 'Input matrices dimensions mismatch'
189+
return c
190+
191+
192+
def get_shapes(B, M, N, K, transpose_a, transpose_b):
193+
a_shape = (M, K)
194+
if transpose_a:
195+
a_shape = (K, M)
196+
197+
b_shape = (K, N)
198+
if transpose_b:
199+
b_shape = (N, K)
200+
201+
if B != 1:
202+
a_shape = (B, *a_shape)
203+
b_shape = (B, *b_shape)
204+
return a_shape, b_shape
205+
206+
207+
X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
208+
[1, 1, 13824, 5120],
209+
[1, 4, 12288, 4096],
210+
[1, 512, 8192, 8192],
211+
[1, 512, 8192, 32768],
212+
[1, 512, 32768, 8192],
213+
[1, 1024, 8192, 16384],
214+
[1, 1024, 8192, 28672],
215+
[1, 3072, 3072, 4096], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
216+
[1, 4096, 8192, 16384],
217+
[1, 8192, 1024, 16384],
218+
[1, 8192, 4096, 16384],
219+
[1, 16384, 1024, 8192],
220+
[1, 16384, 4096, 8192],
221+
[1, 16384, 8192, 1024],
222+
[1, 16384, 8192, 4096],
223+
[4, 32768, 128, 4096],
224+
[4, 32768, 4096, 128],
225+
[32, 4096, 128, 4096],
226+
[4096, 8, 128, 16384],
227+
[4096, 8, 16384, 128],
228+
]
229+
DEVICE_NAME = torch.xpu.get_device_name()
230+
DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
231+
232+
233+
def is_enough_memory(x_val):
234+
# x_val: (B, M, N, K)
235+
B, M, N, K = x_val
236+
# a: (B, M, K) bfloat16
237+
# b: (B, N, K) bfloat16
238+
# c: (B, M, N) float32
239+
# pytorch reference: (B, M, N) float32
240+
required_memory = B * M * K * 2 + B * N * K * 2 + 2 * B * M * N * 4
241+
enough_memory = required_memory < DEVICE_TOTAL_MEMORY
242+
if not enough_memory:
243+
print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
244+
return enough_memory
245+
246+
247+
X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]
248+
249+
250+
# Benchmark Performance
251+
@benchmark_suit.perf_report(
252+
benchmark_suit.Benchmark(
253+
# argument names to use as an x-axis for the plot
254+
x_names=['B', 'M', 'N', 'K'],
255+
# different possible values for `x_name`
256+
x_vals=X_VALS,
257+
line_arg='provider',
258+
# argument name whose value corresponds to a different line in the plot
259+
# possible values for `line_arg``
260+
line_vals=['triton', 'onednn'] + (['xetla'] if use_xetla else []),
261+
# label name for the lines
262+
line_names=['Triton', 'OneDNN'] + (['XeTLA'] if use_xetla else []),
263+
# line styles
264+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
265+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
266+
plot_name='matmul-tensor-desc-performance',
267+
# name for the plot. Used also as a file name for saving the plot.
268+
args={},
269+
))
270+
def benchmark(B, M, N, K, provider):
271+
a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
272+
273+
torch.manual_seed(0)
274+
a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
275+
b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
276+
277+
quantiles = [0.5, 0.0, 1.0]
278+
279+
torch_a = a
280+
if TRANSPOSE_A:
281+
torch_a = torch.transpose(torch_a, -2, -1)
282+
283+
torch_b = b
284+
if TRANSPOSE_B:
285+
torch_b = torch.transpose(torch_b, -2, -1)
286+
287+
if provider == 'onednn':
288+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10,
289+
n_repeat=10, quantiles=quantiles)
290+
elif provider == 'triton':
291+
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
292+
if len(a.shape) == 3:
293+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
294+
else:
295+
assert len(a.shape) == 2, 'Expecting shape of length 2'
296+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
297+
triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
298+
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
299+
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
300+
benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
301+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
302+
quantiles=quantiles)
303+
elif provider == 'xetla':
304+
if B == 1:
305+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
306+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
307+
else:
308+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
309+
cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
310+
name = f'gemm_shape_{B}_{M}_{K}_{N}'
311+
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
312+
# better performance.
313+
if (B, M, N, K) == (1, 3072, 3072, 4096):
314+
name = 'gemm_streamk_shape_3072_4096_3072'
315+
func = getattr(xetla_kernel, name)
316+
317+
def xetla_func_with_acc_allocation():
318+
# allocating `acc` matrix on every function call, to be as similar as
319+
# possible to the triton kernel, which also does this on every call.
320+
if B == 1:
321+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
322+
else:
323+
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
324+
return func(a, b, c, acc, cnt)
325+
326+
xetla_fn = xetla_func_with_acc_allocation
327+
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
328+
329+
# benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
330+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
331+
quantiles=quantiles)
332+
else:
333+
raise NotImplementedError(f'Unsupported provider {provider}')
334+
335+
tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
336+
gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
337+
338+
return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
339+
340+
341+
if __name__ == '__main__':
342+
benchmark.run(show_plots=False, print_data=True)

scripts/test-triton.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ run_tutorial_tests() {
252252
run_tutorial_test "06-fused-attention"
253253
run_tutorial_test "07-extern-functions"
254254
run_tutorial_test "08-grouped-gemm"
255-
TRITON_TEST_REPORTS=false run_tutorial_test "09-persistent-matmul"
255+
TRITON_TEST_REPORTS=false run_tutorial_test "09-persistent-matmul"
256256
run_tutorial_test "10-experimental-block-pointer"
257257
run_tutorial_test "10i-experimental-block-pointer"
258258

@@ -292,6 +292,9 @@ run_benchmark_gemm() {
292292

293293
echo "GEMM with tensor of pointer:"
294294
python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py
295+
296+
echo "GEMM with tensor descriptor:"
297+
python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py
295298
}
296299

297300
run_benchmark_attention() {

0 commit comments

Comments
 (0)