Skip to content

Commit b9925d9

Browse files
authored
Support AMD-specific autotune parameters: waves_per_eu and matrix_instr_nonkdim (#1162)
1 parent eecc471 commit b9925d9

File tree

6 files changed

+138
-0
lines changed

6 files changed

+138
-0
lines changed

helion/_compat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import functools
5+
import re
56
from typing import Any
67
from typing import Callable
78
from typing import cast
@@ -286,3 +287,22 @@ def warps_to_threads(num_warps: int) -> int:
286287
)
287288
return num_warps * (props.warp_size or 32)
288289
return num_warps * 32
290+
291+
292+
@functools.cache
def supports_amd_cdna_tunables() -> bool:
    """Report whether the current device accepts the AMD CDNA tunables.

    The answer is computed once per process (``functools.cache``) from the
    ``gcnArchName`` of the current CUDA/HIP device.

    Returns:
        True only when PyTorch is a HIP build, a device is available, and the
        device's base architecture parses as ``gfx<hex>`` with a value in the
        range ``[0x908, 0x1000)``. False in every other case, including when
        the device properties cannot be queried.
    """
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    try:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
    except Exception:
        # Deliberate best-effort probe: any failure to query the device is
        # treated as "not supported" rather than propagated to the caller.
        return False
    arch = getattr(props, "gcnArchName", None)
    if not isinstance(arch, str):
        return False
    # Extract base architecture (e.g., "gfx942" from "gfx942:sramecc+:xnack-")
    base_arch = arch.split(":")[0]
    # Parse the *entire* hex suffix so that four-digit names such as
    # "gfx1100" are compared by their real value instead of being silently
    # truncated to their first three digits.
    match = re.fullmatch(r"gfx([0-9a-f]+)", base_arch)
    if match is None:
        return False
    # CDNA architectures are gfx908 and above but less than gfx1000
    # Reference: https://llvm.org/docs/AMDGPUUsage.html
    return 0x908 <= int(match.group(1), 16) < 0x1000

helion/_compiler/device_function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ def codegen_function_call(self) -> ast.AST:
670670
if x.startswith("_triton_config_")
671671
]
672672
)
673+
for key in ("waves_per_eu", "matrix_instr_nonkdim"):
674+
if key in self.config:
675+
args.append(f"{key}={self.config[key]}")
673676
pid = self.pid
674677
assert pid is not None
675678
# TODO(jansel): we should run CSE this statement

helion/_testing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import triton
2525

2626
from ._compat import get_tensor_descriptor_fn_name
27+
from ._compat import supports_amd_cdna_tunables
2728
from ._utils import counters
2829
from .autotuner.benchmarking import compute_repeat
2930
from .autotuner.benchmarking import interleaved_bench
@@ -37,6 +38,13 @@
3738
from .runtime.kernel import Kernel
3839

3940

41+
def _strip_amd_launcher_args(value: str) -> str:
    """Remove AMD-only launcher keyword arguments from generated code text.

    When ``supports_amd_cdna_tunables()`` is false the text is returned
    untouched. Otherwise every ``, waves_per_eu=<digits>`` and
    ``, matrix_instr_nonkdim=<digits>`` occurrence is deleted.
    """
    if supports_amd_cdna_tunables():
        # Order matches the original: waves_per_eu first, then
        # matrix_instr_nonkdim.
        for pattern in (r", waves_per_eu=\d+", r", matrix_instr_nonkdim=\d+"):
            value = re.sub(pattern, "", value)
    return value
46+
47+
4048
def _get_triton_backend() -> str | None:
4149
try:
4250
# pyrefly: ignore [missing-attribute]
@@ -130,6 +138,13 @@ def skipIfRocm(reason: str) -> Callable[[Callable], Callable]:
130138
return unittest.skipIf(torch.version.hip is not None, reason)
131139

132140

141+
def skipUnlessAMDCDNA(reason: str) -> Callable[[Callable], Callable]:
    """Skip test unless running on AMD CDNA architecture.

    Args:
        reason: Message reported by unittest when the test is skipped.

    Returns:
        The decorator produced by ``unittest.skipUnless``; the hardware check
        runs once, when this function is called at decoration time.
    """
    # supports_amd_cdna_tunables is imported at the top of this module, so no
    # function-scope re-import is needed here.
    return unittest.skipUnless(supports_amd_cdna_tunables(), reason)
146+
147+
133148
def skipIfXPU(reason: str) -> Callable[[Callable], Callable]:
134149
"""Skip test if running with Intel XPU"""
135150
return unittest.skipIf(torch.xpu.is_available(), reason)
@@ -1029,7 +1044,9 @@ def assertExpectedJournal(self, value: str) -> None:
10291044
Note:
10301045
Use EXPECTTEST_ACCEPT=1 environment variable to update expected outputs.
10311046
"""
1047+
value = _strip_amd_launcher_args(value)
10321048
value, expected = self._expected_journal.lookup(self.id(), value)
1049+
expected = _strip_amd_launcher_args(expected)
10331050
self.assertMultiLineEqual(
10341051
value,
10351052
expected,

helion/autotuner/config_spec.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from torch._inductor.runtime.runtime_utils import next_power_of_2
1010

11+
from .._compat import supports_amd_cdna_tunables
1112
from .._compat import supports_tensor_descriptor
1213
from ..exc import InvalidConfig
1314
from .block_id_sequence import BlockIdSequence
@@ -34,6 +35,7 @@
3435

3536
DEFAULT_NUM_WARPS = 4
3637
DEFAULT_NUM_STAGES = 1
38+
AMD_CDNA_TUNABLES = ("waves_per_eu", "matrix_instr_nonkdim")
3739
VALID_KEYS: frozenset[str] = frozenset(
3840
[
3941
"block_sizes",
@@ -52,10 +54,13 @@
5254
"pid_type",
5355
"indexing",
5456
"load_eviction_policies",
57+
*AMD_CDNA_TUNABLES,
5558
]
5659
)
5760
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
5861
VALID_EVICTION_POLICIES = ("", "first", "last")
62+
VALID_WAVES_PER_EU = (1, 2, 3, 4)
63+
VALID_MATRIX_INSTR_NONKDIM = (0, 16, 32)
5964

6065

6166
@dataclasses.dataclass
@@ -112,6 +117,20 @@ class ConfigSpec:
112117
length=0,
113118
)
114119
)
120+
waves_per_eu: ConfigSpecFragment | None = dataclasses.field(
121+
default_factory=lambda: (
122+
EnumFragment(choices=VALID_WAVES_PER_EU)
123+
if supports_amd_cdna_tunables()
124+
else None
125+
)
126+
)
127+
matrix_instr_nonkdim: ConfigSpecFragment | None = dataclasses.field(
128+
default_factory=lambda: (
129+
EnumFragment(choices=VALID_MATRIX_INSTR_NONKDIM)
130+
if supports_amd_cdna_tunables()
131+
else None
132+
)
133+
)
115134

116135
@staticmethod
117136
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -226,6 +245,12 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
226245
"load_eviction_policies", self.load_eviction_policies.default()
227246
)
228247
config.setdefault("indexing", self.indexing.default())
248+
for key in AMD_CDNA_TUNABLES:
249+
if (fragment := getattr(self, key)) is not None:
250+
config.setdefault(key, fragment.default())
251+
elif key in config:
252+
raise InvalidConfig(f"{key} is not supported on this target hardware")
253+
229254
# TODO(jansel): include num_ctas and max_nreg
230255

231256
for name, values in (("pid_type", VALID_PID_TYPES),):

test/test_amd_cdna.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import patch
4+
5+
import torch
6+
7+
import helion
8+
from helion._compiler.compile_environment import CompileEnvironment
9+
from helion._testing import DEVICE
10+
from helion._testing import TestCase
11+
from helion._testing import code_and_output
12+
from helion._testing import skipUnlessAMDCDNA
13+
import helion.language as hl
14+
15+
16+
class TestAMDCDNA(TestCase):
    """Tests for the ``waves_per_eu`` / ``matrix_instr_nonkdim`` config keys."""

    @skipUnlessAMDCDNA("Test requires AMD CDNA GPU (MI200/MI300 series)")
    def test_amd_cdna_tunables_in_kernel(self) -> None:
        """Test that AMD CDNA tunables are supported."""

        @helion.kernel(
            autotune_effort="none",
            config=helion.Config(
                block_sizes=[32, 32],
                waves_per_eu=2,
                matrix_instr_nonkdim=16,
            ),
        )
        def add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            result = torch.empty_like(x)
            for tile in hl.tile(x.shape):
                result[tile] = x[tile] + y[tile]
            return result

        lhs = torch.randn(128, 128, device=DEVICE, dtype=torch.float32)
        rhs = torch.randn(128, 128, device=DEVICE, dtype=torch.float32)

        generated, actual = code_and_output(add_kernel, (lhs, rhs))

        # Numerical result must match eager addition.
        torch.testing.assert_close(actual, lhs + rhs)

        # Verify that the tunables are passed to Triton
        for launcher_arg in ("waves_per_eu=2", "matrix_instr_nonkdim=16"):
            self.assertIn(launcher_arg, generated)

    def test_amd_tunables_error_when_not_supported(self) -> None:
        """Test that specifying AMD tunables on non-AMD hardware raises an error."""
        # Build the environment while the capability probe is forced to report
        # "unsupported", so the config spec is created without the AMD fragments.
        with patch(
            "helion.autotuner.config_spec.supports_amd_cdna_tunables",
            return_value=False,
        ):
            env = CompileEnvironment(torch.device("cuda"), helion.Settings())

        rejected = (
            (
                helion.Config(waves_per_eu=2),
                "waves_per_eu is not supported on this target hardware",
            ),
            (
                helion.Config(matrix_instr_nonkdim=16),
                "matrix_instr_nonkdim is not supported on this target hardware",
            ),
        )
        for config, message in rejected:
            with self.assertRaisesRegex(helion.exc.InvalidConfig, message):
                env.config_spec.normalize(config)

test/test_examples_dist.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from helion._testing import TestCase
1313
from helion._testing import code_and_output
1414
from helion._testing import import_path
15+
from helion._testing import skipIfRocm
1516

1617

1718
@instantiate_parametrized_tests
@@ -43,6 +44,7 @@ def _init_process(self):
4344
)
4445
torch.manual_seed(42 + self.rank)
4546

47+
@skipIfRocm("Distributed example requires CUDA/NCCL")
4648
@skip_if_lt_x_gpu(4)
4749
def test_all_gather_matmul(self):
4850
self._init_process()
@@ -100,6 +102,7 @@ def test_all_gather_matmul(self):
100102
torch.cuda.current_stream().wait_stream(backend_stream)
101103
dist.destroy_process_group()
102104

105+
@skipIfRocm("Distributed example requires CUDA/NCCL")
103106
@skip_if_lt_x_gpu(4)
104107
def test_all_reduce(self):
105108
self._init_process()

0 commit comments

Comments
 (0)