|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | +import helion |
| 8 | +from helion._compiler.compile_environment import CompileEnvironment |
| 9 | +from helion._testing import DEVICE |
| 10 | +from helion._testing import TestCase |
| 11 | +from helion._testing import code_and_output |
| 12 | +from helion._testing import skipUnlessAMDCDNA |
| 13 | +import helion.language as hl |
| 14 | + |
| 15 | + |
| 16 | +class TestAMDCDNA(TestCase): |
| 17 | + @skipUnlessAMDCDNA("Test requires AMD CDNA GPU (MI200/MI300 series)") |
| 18 | + def test_amd_cdna_tunables_in_kernel(self) -> None: |
| 19 | + """Test that AMD CDNA tunables are supported.""" |
| 20 | + |
| 21 | + @helion.kernel( |
| 22 | + autotune_effort="none", |
| 23 | + config=helion.Config( |
| 24 | + block_sizes=[32, 32], |
| 25 | + waves_per_eu=2, |
| 26 | + matrix_instr_nonkdim=16, |
| 27 | + ), |
| 28 | + ) |
| 29 | + def add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| 30 | + result = torch.empty_like(x) |
| 31 | + for tile in hl.tile(x.shape): |
| 32 | + result[tile] = x[tile] + y[tile] |
| 33 | + return result |
| 34 | + |
| 35 | + x = torch.randn(128, 128, device=DEVICE, dtype=torch.float32) |
| 36 | + y = torch.randn(128, 128, device=DEVICE, dtype=torch.float32) |
| 37 | + |
| 38 | + code, result = code_and_output(add_kernel, (x, y)) |
| 39 | + expected = x + y |
| 40 | + |
| 41 | + torch.testing.assert_close(result, expected) |
| 42 | + |
| 43 | + # Verify that the tunables are passed to Triton |
| 44 | + self.assertIn("waves_per_eu=2", code) |
| 45 | + self.assertIn("matrix_instr_nonkdim=16", code) |
| 46 | + |
| 47 | + def test_amd_tunables_error_when_not_supported(self) -> None: |
| 48 | + """Test that specifying AMD tunables on non-AMD hardware raises an error.""" |
| 49 | + device = torch.device("cuda") |
| 50 | + settings = helion.Settings() |
| 51 | + |
| 52 | + with patch( |
| 53 | + "helion.autotuner.config_spec.supports_amd_cdna_tunables", |
| 54 | + return_value=False, |
| 55 | + ): |
| 56 | + env = CompileEnvironment(device, settings) |
| 57 | + |
| 58 | + config = helion.Config(waves_per_eu=2) |
| 59 | + with self.assertRaisesRegex( |
| 60 | + helion.exc.InvalidConfig, |
| 61 | + "waves_per_eu is not supported on this target hardware", |
| 62 | + ): |
| 63 | + env.config_spec.normalize(config) |
| 64 | + |
| 65 | + config = helion.Config(matrix_instr_nonkdim=16) |
| 66 | + with self.assertRaisesRegex( |
| 67 | + helion.exc.InvalidConfig, |
| 68 | + "matrix_instr_nonkdim is not supported on this target hardware", |
| 69 | + ): |
| 70 | + env.config_spec.normalize(config) |
0 commit comments