Commit a06add0

[TEST] Deprecate triton operators and add end to end pipeline tests (intel#4156)
1 parent 75b0321 commit a06add0

File tree: 18 files changed (+193, -2437 lines)


.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
@@ -257,7 +257,6 @@ jobs:
           cd python/test/unit
           python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
                    language/test_random.py language/test_block_pointer.py language/test_subprocess.py \
-                   operators/test_flash_attention.py::test_op \
                    ../../tutorials/06-fused-attention.py::test_op --device cpu
       - name: Run C++ unittests
         run: |
@@ -384,7 +383,7 @@ jobs:
           cd python/test/unit
           ## test_subprocess.py is flaky on the AMD CI.
           ## TODO (lixun) find a solution and re-enable it.
-          pytest --capture=tee-sys -rfs -n 32 language operators \
+          pytest --capture=tee-sys -rfs -n 32 language \
            hopper/test_mixed_io.py \
            hopper/test_gemm.py \
            hopper/test_tma_store_gemm.py \

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
@@ -293,7 +293,6 @@ jobs:
           cd python/test/unit
           python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
                    language/test_random.py language/test_block_pointer.py language/test_subprocess.py \
-                   operators/test_flash_attention.py::test_op \
                    ../../tutorials/06-fused-attention.py::test_op --device cpu

       - &run-cpp-unittests-step
@@ -388,7 +387,7 @@ jobs:
           cd python/test/unit
           ## test_subprocess.py is flaky on the AMD CI.
           ## TODO (lixun) find a solution and re-enable it.
-          pytest --capture=tee-sys -rfs -n 32 language operators \
+          pytest --capture=tee-sys -rfs -n 32 language \
            hopper/test_mixed_io.py \
            hopper/test_gemm.py \
            hopper/test_tma_store_gemm.py \

python/setup.py

Lines changed: 0 additions & 2 deletions
@@ -557,8 +557,6 @@ def get_packages():
         "triton/language/extra",
         "triton/language/extra/cuda",
         "triton/language/extra/hip",
-        "triton/ops",
-        "triton/ops/blocksparse",
         "triton/runtime",
         "triton/backends",
         "triton/tools",
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+# End-to-end tests to check the correctness of the pipeliner
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+import numpy as np
+
+
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_cuda_tma_available():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def is_hip_mi200():
+    target = triton.runtime.driver.active.get_current_target()
+    return target.backend == 'hip' and target.arch == 'gfx90a'
+
+
+def check_capabilities():
+    if is_cuda():
+        cc = torch.cuda.get_device_capability()
+        if cc[0] < 8:
+            pytest.skip("CUDA 8.0+ required")
+
+
+@triton.jit
+def matmul_kernel(  #
+        a_ptr, b_ptr, output_ptr,  #
+        M, N, K,  #
+        stride_am, stride_ak,  #
+        stride_bk, stride_bn,  #
+        stride_cm, stride_cn,  #
+        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
+        NUM_STAGES: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
+    pid_m = pid % num_pid_m
+    pid_n = pid // num_pid_m
+    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
+    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
+    offs_k = tl.arange(0, BLOCK_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
+        mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
+        mask_b = ((offs_k[:, None] + k * BLOCK_K) < K) & (offs_bn[None, :] < N)
+        a = tl.load(a_ptrs, mask=mask_a, other=0)
+        b = tl.load(b_ptrs, mask=mask_b, other=0)
+        accumulator = tl.dot(a, b, acc=accumulator)
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+    accumulator = accumulator.to(tl.float16)
+    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    tl.store(output_ptrs, accumulator, mask=mask_c)
+
+
+@triton.jit
+def matmul_kernel_tma(  #
+        a_ptr, b_ptr, output_ptr,  #
+        M, N, K,  #
+        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
+        NUM_STAGES: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
+    pid_m = pid % num_pid_m
+    pid_n = pid // num_pid_m
+    offs_am = (pid_m * BLOCK_M) % M
+    offs_bn = (pid_n * BLOCK_N) % N
+    offs_am = tl.multiple_of(offs_am, BLOCK_M)
+    offs_bn = tl.multiple_of(offs_bn, BLOCK_N)
+    offs_k = 0
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for _ in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
+        a = tl._experimental_descriptor_load(a_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], tl.float16)
+        b = tl._experimental_descriptor_load(b_ptr, [offs_k, offs_bn], [BLOCK_K, BLOCK_N], tl.float16)
+        accumulator = tl.dot(a, b, acc=accumulator)
+        offs_k += BLOCK_K
+    accumulator = accumulator.to(tl.float16)
+    tl._experimental_descriptor_store(output_ptr, accumulator, [offs_am, offs_bn])
+
+
+@triton.jit
+def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE * num_blocks
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    for _ in tl.range(0, num_blocks, num_stages=NUM_STAGES):
+        mask = offsets < n_elements
+        x = tl.load(a_ptr + offsets, mask=mask)
+        y = tl.load(b_ptr + offsets, mask=mask)
+        output = x + y
+        tl.store(output_ptr + offsets, output, mask=mask)
+        offsets += BLOCK_SIZE
+
+
+def test_pipeline_matmul(device):
+    check_capabilities()
+    M, N, K = 512, 512, 128
+    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
+    NUM_STAGES = 4
+    a = torch.randn(M, K, device=device, dtype=torch.float16)
+    b = torch.randn(K, N, device=device, dtype=torch.float16)
+    output = torch.empty((M, N), dtype=torch.float16, device=device)
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    if is_cuda_tma_available():
+        TMA_SIZE = 128
+
+        desc_a = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_b = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_output = np.empty(TMA_SIZE, dtype=np.int8)
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K,
+                                                                   a.element_size(), desc_a)
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N,
+                                                                   b.element_size(), desc_b)
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(output.data_ptr(), M, N, BLOCK_M, BLOCK_N,
+                                                                   output.element_size(), desc_output)
+
+        a_tma = torch.tensor(desc_a, device=device)
+        b_tma = torch.tensor(desc_b, device=device)
+        output_tma = torch.tensor(desc_output, device=device)
+        handler = matmul_kernel_tma[grid](a_tma, b_tma, output_tma, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                                          NUM_STAGES=NUM_STAGES)
+    else:
+        handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
+                                      output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
+                                      NUM_STAGES=NUM_STAGES)
+    ref_out = torch.matmul(a, b)
+    atol = 1e-2 if is_hip() else None
+    # Bigger tolerance for AMD MI200 devices.
+    # MI200 devices use reduced precision fp16 and bf16 and flush input and
+    # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    rtol = 1e-2 if is_hip_mi200() else None
+    torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
+    if is_cuda():
+        ttgir = handler.asm["ttgir"]
+        if is_cuda_tma_available():
+            assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found"
+            assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match"
+            # a_tma, b_tma, output_tma, barrier
+            assert ttgir.count("triton_gpu.local_alloc") == 4, "alloc number not match"
+            assert ttgir.count("triton_nvidia_gpu.barrier_expect") != 0, "barrier_expect not found"
+            assert ttgir.count("triton_nvidia_gpu.wait_barrier") != 0, "wait_barrier not found"
+            assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found"
+        else:
+            # 1. check async
+            assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found"
+            # 2. check number of stages
+            assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
+            # 3. check alloc
+            assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
+            # 4. check dot
+            cc = torch.cuda.get_device_capability()
+            if cc[0] >= 9:
+                assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found"
+            else:
+                assert ttgir.count("triton_gpu.dot") != 0, "dot not found"
+
+
+def test_pipeline_vecadd(device):
+    check_capabilities()
+    SIZE = 4096
+    NUM_BLOCKS = 4
+    BLOCK_SIZE = 256
+    NUM_STAGES = 3
+    a = torch.randn(SIZE, dtype=torch.float16, device=device)
+    b = torch.randn(SIZE, dtype=torch.float16, device=device)
+    output = torch.empty(SIZE, dtype=torch.float16, device=device)
+    grid = (triton.cdiv(SIZE, NUM_BLOCKS * BLOCK_SIZE), 1)
+    handler = vecadd_kernel[grid](a, b, output, SIZE, NUM_BLOCKS, BLOCK_SIZE, NUM_STAGES)
+    ref_out = a + b
+    torch.testing.assert_close(ref_out, output)
+    if is_cuda():
+        ttgir = handler.asm["ttgir"]
+        # 1. check async
+        assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found"
+        # 2. check number of stages
+        assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
+        # 3. check alloc
+        assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
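
Both tests take the suite's device fixture but are otherwise plain functions, so they can also be smoke-run outside pytest. A minimal sketch, assuming the new file is importable as test_pipeliner (a placeholder name, not confirmed by this view) and a CUDA device with compute capability 8.0 or newer:

# Hedged sketch (not part of the diff): smoke-run the new end-to-end tests
# directly. "test_pipeliner" is an assumed module name for the file added above;
# adjust the import and the device string ("cuda", "xpu", ...) to your setup.
from test_pipeliner import test_pipeline_matmul, test_pipeline_vecadd

test_pipeline_matmul("cuda")  # matmul numerics + pipelined TTGIR checks
test_pipeline_vecadd("cuda")  # vecadd numerics + pipelined TTGIR checks
print("pipeliner end-to-end tests passed")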

python/test/unit/operators/conftest.py

Lines changed: 0 additions & 5 deletions
This file was deleted.
