# End-to-end tests to check the correctness of the pipeliner
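# The tests rely on a pytest `device` fixture (typically provided by the test suite's
# conftest.py) and are run with pytest, e.g. `pytest -v <this file>`.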

import pytest
import torch
import triton
import triton.language as tl
import numpy as np


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_cuda_tma_available():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'


def check_capabilities():
    if is_cuda():
        cc = torch.cuda.get_device_capability()
        if cc[0] < 8:
            pytest.skip("CUDA compute capability 8.0+ required")


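# Classic blocked GEMM. Writing the K loop as tl.range(..., num_stages=NUM_STAGES) asks
# the compiler's software pipeliner to overlap the tile loads with the tl.dot compute.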
@triton.jit
def matmul_kernel(  #
        a_ptr, b_ptr, output_ptr,  #
        M, N, K,  #
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,  #
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
        NUM_STAGES: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
        mask_b = ((offs_k[:, None] + k * BLOCK_K) < K) & (offs_bn[None, :] < N)
        a = tl.load(a_ptrs, mask=mask_a, other=0)
        b = tl.load(b_ptrs, mask=mask_b, other=0)
        accumulator = tl.dot(a, b, acc=accumulator)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    accumulator = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(output_ptrs, accumulator, mask=mask_c)


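# GEMM variant for Hopper (sm90+): a_ptr, b_ptr and output_ptr are host-built TMA
# descriptors, and tl._experimental_descriptor_load/store move whole tiles through TMA,
# so the pipelined loop should use TMA copies and mbarriers instead of plain async copies.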
@triton.jit
def matmul_kernel_tma(  #
        a_ptr, b_ptr, output_ptr,  #
        M, N, K,  #
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
        NUM_STAGES: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = (pid_m * BLOCK_M) % M
    offs_bn = (pid_n * BLOCK_N) % N
    offs_am = tl.multiple_of(offs_am, BLOCK_M)
    offs_bn = tl.multiple_of(offs_bn, BLOCK_N)
    offs_k = 0
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = tl._experimental_descriptor_load(a_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], tl.float16)
        b = tl._experimental_descriptor_load(b_ptr, [offs_k, offs_bn], [BLOCK_K, BLOCK_N], tl.float16)
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_K
    accumulator = accumulator.to(tl.float16)
    tl._experimental_descriptor_store(output_ptr, accumulator, [offs_am, offs_bn])


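# Element-wise add where each program walks num_blocks consecutive BLOCK_SIZE chunks;
# pipelining the loop lets loads of the next chunk overlap work on the current one.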
@triton.jit
def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE * num_blocks
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    for _ in tl.range(0, num_blocks, num_stages=NUM_STAGES):
        mask = offsets < n_elements
        x = tl.load(a_ptr + offsets, mask=mask)
        y = tl.load(b_ptr + offsets, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)
        offsets += BLOCK_SIZE


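# Run the matmul kernel (the TMA variant when the hardware supports it), compare the
# result against torch.matmul, then inspect the TTGIR to confirm the pipeliner ran.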
def test_pipeline_matmul(device):
    check_capabilities()
    M, N, K = 512, 512, 128
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    NUM_STAGES = 4
    a = torch.randn(M, K, device=device, dtype=torch.float16)
    b = torch.randn(K, N, device=device, dtype=torch.float16)
    output = torch.empty((M, N), dtype=torch.float16, device=device)
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    if is_cuda_tma_available():
        TMA_SIZE = 128

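        # A 2D TMA descriptor (CUtensorMap) is a 128-byte opaque blob: it is filled on
        # the host by the driver utility below and then copied to device memory.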
        desc_a = np.empty(TMA_SIZE, dtype=np.int8)
        desc_b = np.empty(TMA_SIZE, dtype=np.int8)
        desc_output = np.empty(TMA_SIZE, dtype=np.int8)
        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K,
                                                                  a.element_size(), desc_a)
        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N,
                                                                  b.element_size(), desc_b)
        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(output.data_ptr(), M, N, BLOCK_M, BLOCK_N,
                                                                  output.element_size(), desc_output)

        a_tma = torch.tensor(desc_a, device=device)
        b_tma = torch.tensor(desc_b, device=device)
        output_tma = torch.tensor(desc_output, device=device)
        handler = matmul_kernel_tma[grid](a_tma, b_tma, output_tma, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
                                          NUM_STAGES=NUM_STAGES)
    else:
        handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                                      output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                                      NUM_STAGES=NUM_STAGES)
    ref_out = torch.matmul(a, b)
    # Use bigger tolerances on AMD devices. MI200 GPUs use reduced-precision fp16/bf16
    # and flush input and output denormal values to zero. Detailed info is at:
    # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
    atol = 1e-2 if is_hip() else None
    rtol = 1e-2 if is_hip_mi200() else None
    torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
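    # Inspect the generated TTGIR to make sure the pipeliner actually transformed the loop.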
    if is_cuda():
        ttgir = handler.asm["ttgir"]
        if is_cuda_tma_available():
            assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found"
            assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num attribute should not be present with TMA"
            # a_tma, b_tma, output_tma, barrier
            assert ttgir.count("triton_gpu.local_alloc") == 4, "alloc number not match"
            assert ttgir.count("triton_nvidia_gpu.barrier_expect") != 0, "barrier_expect not found"
            assert ttgir.count("triton_nvidia_gpu.wait_barrier") != 0, "wait_barrier not found"
            assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found"
        else:
            # 1. check async
            assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found"
            # 2. check number of stages
            assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
            # 3. check alloc
            assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
            # 4. check dot
            cc = torch.cuda.get_device_capability()
            if cc[0] >= 9:
                assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found"
            else:
                assert ttgir.count("tt.dot") != 0, "dot not found"


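# Same pipelining checks for a simple vector add: compare against eager addition, then
# verify that the TTGIR contains async copies with the requested number of stages.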
def test_pipeline_vecadd(device):
    check_capabilities()
    SIZE = 4096
    NUM_BLOCKS = 4
    BLOCK_SIZE = 256
    NUM_STAGES = 3
    a = torch.randn(SIZE, dtype=torch.float16, device=device)
    b = torch.randn(SIZE, dtype=torch.float16, device=device)
    output = torch.empty(SIZE, dtype=torch.float16, device=device)
    grid = (triton.cdiv(SIZE, NUM_BLOCKS * BLOCK_SIZE), 1)
    handler = vecadd_kernel[grid](a, b, output, SIZE, NUM_BLOCKS, BLOCK_SIZE, NUM_STAGES)
    ref_out = a + b
    torch.testing.assert_close(ref_out, output)
    if is_cuda():
        ttgir = handler.asm["ttgir"]
        # 1. check async
        assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found"
        # 2. check number of stages
        assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
        # 3. check alloc
        assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"