From 9a3322a6db8fcc5a4bf3403cc116260d4ff7fd72 Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Mon, 16 Jun 2025 14:12:36 +0000
Subject: [PATCH 1/4] Add block ptr test for dot product with transpose

---
 python/test/unit/intel/test_block_load.py | 143 +++++++++++++++++++++-
 1 file changed, 142 insertions(+), 1 deletion(-)

diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py
index 9c5b26b5a1..7e021a3dd9 100644
--- a/python/test/unit/intel/test_block_load.py
+++ b/python/test/unit/intel/test_block_load.py
@@ -1,8 +1,10 @@
 import pytest
 import torch
 import pathlib
+from functools import partial
 
 import triton
+import triton.language as tl
 from triton._internal_testing import is_xpu
 
 
@@ -74,5 +76,144 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
 
     kernel = triton.compile(str(temp_file))
     kernel[(1, 1, 1)](a, x, b, y)
-    #import pdb; pdb.set_trace()
     assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)
+
+
+@pytest.mark.parametrize("BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K",
+                         [[256, 256, 32], [256, 64, 32], [64, 256, 32], [64, 128, 32], [64, 64, 32], [32, 32, 32],
+                          [32, 32, 16], [16, 16, 16], [8, 32, 16], [8, 512, 64]])
+@pytest.mark.parametrize("GROUP_SIZE_M", [4, 1])
+@pytest.mark.parametrize("TRANSPOSE_A", [True, False])
+@pytest.mark.parametrize("TRANSPOSE_B", [True, False])
+@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.xfail(
+    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
+         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
+    reason="Block loads not supported on this architecture")
+def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
+                                device):
+    if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
+        # skip large block sizes as they will be too slow
+        pytest.skip("Skipping slow combinations")
+
+    @triton.jit
+    def matmul_kernel_with_block_pointers(
+            # Pointers to matrices
+            a_ptr, b_ptr,  #bias_ptr,
+            c_ptr,
+            # Matrix dimensions
+            M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+            # The stride variables represent how much to increase the ptr by when moving by 1
+            # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+            # by to get the element one row down (A has M rows).
+            stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+            stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+            stride_cm: tl.constexpr, stride_cn: tl.constexpr, BIAS_REQD: tl.constexpr,
+            # Meta-parameters
+            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+            GROUP_SIZE_M: tl.constexpr):
+        """Kernel for computing the matmul C = A x B.
+        A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+        """
+        # -----------------------------------------------------------
+        # Map program ids `pid` to the block of C it should compute.
+        # This is done in a grouped ordering to promote L2 data reuse.
+        # See the matrix multiplication tutorial for details.
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+        #tl.device_print("pid", pid_m)
+
+        # ----------------------------------------------------------
+        # Create block pointers for the first blocks of A and B.
+        # We will advance this pointer as we move in the K direction and accumulate.
+        # See above `Make a Block Pointer` section for details.
+        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                        offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                        order=(1, 0))
+        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                        offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                        order=(1, 0))
+
+        # -----------------------------------------------------------
+        # Iterate to compute a block of the C matrix.
+        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+        # of fp32 values for higher accuracy.
+        # `accumulator` will be converted back to fp16 after the loop.
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            # Load with boundary checks; no need to calculate the mask manually.
+            # For better performance, you may remove some axis from the boundary
+            # check, if you can guarantee that the access is always in-bound in
+            # that axis.
+            # See above `Load/Store a Block Pointer` section for details.
+            a = tl.load(a_block_ptr, boundary_check=(0, 1))
+            b = tl.load(b_block_ptr, boundary_check=(0, 1))
+            # We accumulate along the K dimension.
+            accumulator += tl.dot(a, b)
+            # Advance the block pointer to the next K block.
+            # See above `Advance a Block Pointer` section for details.
+            a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
+            b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        c = accumulator.to(tl.float32)
+        # add bias to accumulator
+
+        #if BIAS_REQD:
+        #    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        #    bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32)
+        #    c += bias[None, :]
+        # ----------------------------------------------------------------
+        # Write back the block of the output matrix C with boundary checks.
+        # See above `Load/Store a Block Pointer` section for details.
+        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
+                                        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+        tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))
+
+    def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
+        if transpose_x:
+            K, M = X.shape
+            Xstride0, Xstride1 = X.stride(1), X.stride(0)
+        else:
+            M, K = X.shape
+            Xstride0, Xstride1 = X.stride(0), X.stride(1)
+        if transpose_y:
+            N, _ = Y.shape
+            Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
+        else:
+            _, N = Y.shape
+            Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
+        # Allocates output.
+        Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
+        # 1D launch kernel where each block gets its own program.
+        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+
+        matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Wstride0, Wstride1, Z.stride(0),
+                                                Z.stride(1), BIAS_REQD=b is not None, BLOCK_SIZE_M=BLOCK_SIZE_M,
+                                                BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
+                                                GROUP_SIZE_M=GROUP_SIZE_M)
+
+        return Z
+
+    M = 512
+    K = 64
+    N = 512
+    dtype = torch.float16
+    torch.manual_seed(0)
+
+    X = torch.randn((M, K) if not TRANSPOSE_A else (K, M), device=device, dtype=dtype, requires_grad=False)
+    Y = torch.randn((K, N) if not TRANSPOSE_B else (N, K), device=device, dtype=dtype, requires_grad=False)
+
+    fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
+    fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)
+
+    rtol = 1e-3
+    result_tor = fn_tor()
+    result_tri = fn_tri()
+    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=rtol)

From d852a35465a96fffdbd3bba6814a9ae4ac906ca7 Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Mon, 16 Jun 2025 19:39:46 +0000
Subject: [PATCH 2/4] address review comments

---
 python/test/unit/intel/test_block_load.py | 27 +++++++----------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py
index 7e021a3dd9..948e957fcd 100644
--- a/python/test/unit/intel/test_block_load.py
+++ b/python/test/unit/intel/test_block_load.py
@@ -89,18 +89,17 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
 @pytest.mark.xfail(
     not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
          and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
-    reason="Block loads not supported on this architecture")
+    reason="Block loads and/or DPAS not supported on this architecture")
 def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
                                 device):
     if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
         # skip large block sizes as they will be too slow
-        pytest.skip("Skipping slow combinations")
+        pytest.xfail("Skipping slow combinations")
 
     @triton.jit
     def matmul_kernel_with_block_pointers(
             # Pointers to matrices
-            a_ptr, b_ptr,  #bias_ptr,
-            c_ptr,
+            a_ptr, b_ptr, c_ptr,
             # Matrix dimensions
             M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
             # The stride variables represent how much to increase the ptr by when moving by 1
@@ -108,10 +107,9 @@ def matmul_kernel_with_block_pointers(
             # by to get the element one row down (A has M rows).
             stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
             stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
-            stride_cm: tl.constexpr, stride_cn: tl.constexpr, BIAS_REQD: tl.constexpr,
+            stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
             # Meta-parameters
-            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-            GROUP_SIZE_M: tl.constexpr):
+            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
         """Kernel for computing the matmul C = A x B.
         A has shape (M, K), B has shape (K, N) and C has shape (M, N)
         """
@@ -162,15 +160,7 @@ def matmul_kernel_with_block_pointers(
             a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
             b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
         c = accumulator.to(tl.float32)
-        # add bias to accumulator
-
-        #if BIAS_REQD:
-        #    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-        #    bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32)
-        #    c += bias[None, :]
-        # ----------------------------------------------------------------
-        # Write back the block of the output matrix C with boundary checks.
-        # See above `Load/Store a Block Pointer` section for details.
+
         c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                         offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                         block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
         tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))
@@ -195,9 +185,8 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
 
         matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Wstride0, Wstride1, Z.stride(0),
-                                                Z.stride(1), BIAS_REQD=b is not None, BLOCK_SIZE_M=BLOCK_SIZE_M,
-                                                BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
-                                                GROUP_SIZE_M=GROUP_SIZE_M)
+                                                Z.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
+                                                BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=GROUP_SIZE_M)
 
         return Z

From 1e80f67efcdb3a75fbf2208a414ce8d9869c65b1 Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Mon, 16 Jun 2025 19:40:38 +0000
Subject: [PATCH 3/4] remove device print

---
 python/test/unit/intel/test_block_load.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py
index 948e957fcd..54739d40d0 100644
--- a/python/test/unit/intel/test_block_load.py
+++ b/python/test/unit/intel/test_block_load.py
@@ -126,7 +126,6 @@ def matmul_kernel_with_block_pointers(
         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
         pid_m = first_pid_m + (pid % group_size_m)
         pid_n = (pid % num_pid_in_group) // group_size_m
-        #tl.device_print("pid", pid_m)
 
         # ----------------------------------------------------------
         # Create block pointers for the first blocks of A and B.

From 06f2a5a30061c262a6937fd381e5ff62c8f2beb1 Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Mon, 16 Jun 2025 15:42:24 -0400
Subject: [PATCH 4/4] Update python/test/unit/intel/test_block_load.py

Co-authored-by: Whitney Tsang
---
 python/test/unit/intel/test_block_load.py | 3 +-
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py
index 54739d40d0..45ce9e5d8c 100644
--- a/python/test/unit/intel/test_block_load.py
+++ b/python/test/unit/intel/test_block_load.py
@@ -201,7 +201,6 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
     fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
     fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)
 
-    rtol = 1e-3
    result_tor = fn_tor()
     result_tri = fn_tri()
-    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
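
Note on the grouped launch ordering added in PATCH 1/4: the kernel remaps the
linear program id `pid` into a block coordinate (pid_m, pid_n) so that nearby
programs load overlapping tiles of A and B, improving L2 hit rates. The
mapping can be sanity-checked on the host in plain Python. The sketch below is
illustrative only (the helper name `launch_order` is not part of the patch);
it mirrors the kernel's index arithmetic line for line:

    import math

    def launch_order(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
        # Same arithmetic as the pid-mapping block in
        # matmul_kernel_with_block_pointers, evaluated for every pid
        # in the 1D launch grid.
        num_pid_m = math.ceil(M / BLOCK_SIZE_M)
        num_pid_n = math.ceil(N / BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        order = []
        for pid in range(num_pid_m * num_pid_n):
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (pid % group_size_m)
            pid_n = (pid % num_pid_in_group) // group_size_m
            order.append((pid_m, pid_n))
        return order

    # With the test's M = N = 512 and a hypothetical 128x128 output block,
    # GROUP_SIZE_M=4 visits four pid_m rows of one pid_n column before moving
    # to the next column, so consecutive programs hit the same B tiles:
    print(launch_order(512, 512, 128, 128, 4)[:8])
    # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]

With GROUP_SIZE_M=1 the mapping degenerates into a plain row-major sweep of C
blocks, which is why the test parametrizes over both GROUP_SIZE_M=4 and 1.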
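
Note on the tolerances in PATCH 4/4: torch.testing.assert_close treats values
as equal when |actual - expected| <= atol + rtol * |expected|. As a rough
worked bound (an estimate, not part of the patch): with K = 64 and
unit-variance fp16 inputs, a typical output element has magnitude around
sqrt(64) = 8, so atol=1e-2 and rtol=1e-3 allow a slack of roughly
1e-2 + 1e-3 * 8 = 0.018 per element, enough to absorb fp16 accumulation
differences between the block-load kernel and torch.mm without masking real
indexing bugs.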