Skip to content

Commit e834fa7

Browse files
authored
Add Settings.autotune_force_persistent (#1130)
1 parent 5699050 commit e834fa7

File tree

4 files changed

+43
-0
lines changed

4 files changed

+43
-0
lines changed

docs/api/settings.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
106106
107107
Force autotuning even when explicit configs are provided. Default is ``False``. Controlled by ``HELION_FORCE_AUTOTUNE=1``.
108108
109+
.. autoattribute:: Settings.autotune_force_persistent
110+
111+
Restrict ``pid_type`` choices to the persistent strategies (``"persistent_blocked"`` or ``"persistent_interleaved"``).
112+
Default is ``False``. Enable globally with ``HELION_AUTOTUNE_FORCE_PERSISTENT=1`` or per kernel via ``@helion.kernel(..., autotune_force_persistent=True)``.
113+
109114
.. autoattribute:: Settings.autotune_log_level
110115
111116
Controls verbosity of autotuning output using Python logging levels:
@@ -258,6 +263,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
258263
| ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
259264
| ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
260265
| ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
266+
| ``HELION_AUTOTUNE_FORCE_PERSISTENT`` | ``autotune_force_persistent`` | Restrict ``pid_type`` to persistent kernel strategies during config search. |
261267
| ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
262268
| ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
263269
| ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9292
self.block_sizes: list[BlockSizeInfo] = []
9393
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
9494
self.config_spec = ConfigSpec()
95+
if settings.autotune_force_persistent:
96+
for pid_type in ("flat", "xyz"):
97+
self.config_spec.disallow_pid_type(pid_type)
9598
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
9699
collections.Counter()
97100
)

helion/runtime/settings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ class _Settings:
285285
0,
286286
)
287287
)
288+
autotune_force_persistent: bool = dataclasses.field(
289+
default_factory=functools.partial(
290+
_env_get_bool,
291+
"HELION_AUTOTUNE_FORCE_PERSISTENT",
292+
False,
293+
)
294+
)
288295
autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
289296
autotune_log: str | None = dataclasses.field(default_factory=_get_autotune_log_path)
290297
autotune_compile_timeout: int = dataclasses.field(
@@ -412,6 +419,10 @@ class Settings(_Settings):
412419
"Number of streaming multiprocessors to reserve when launching persistent kernels. "
413420
"Set HELION_PERSISTENT_RESERVED_SMS=N (default 0) or pass persistent_reserved_sms=N to helion.kernel."
414421
),
422+
"autotune_force_persistent": (
423+
"If True, restrict pid_type choices to persistent kernels only during config selection. "
424+
"Set HELION_AUTOTUNE_FORCE_PERSISTENT=1 to force persistent kernel autotuning globally."
425+
),
415426
"autotune_log_level": (
416427
"Log level for autotuning using Python logging levels. Default is logging.INFO. "
417428
"Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."

test/test_config_api.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22

33
import importlib
44
import inspect
5+
import os
56
import pickle
67
from typing import Any
78
import unittest
9+
from unittest.mock import patch
810

911
from hypothesis import given
1012
from hypothesis import settings
1113
from hypothesis import strategies as st
14+
import torch
1215

1316
import helion
17+
from helion._compiler.compile_environment import CompileEnvironment
1418
from helion._testing import TestCase
1519

1620

@@ -232,5 +236,24 @@ def test_pre_serialized_json_backward_compat(self) -> None:
232236
self.assertEqual(dict(reread), expected)
233237

234238

239+
class TestSettingsEnv(TestCase):
240+
def test_persistent_reserved_sms_env_var(self) -> None:
241+
with patch.dict(
242+
os.environ,
243+
{"HELION_PERSISTENT_RESERVED_SMS": "5"},
244+
clear=False,
245+
):
246+
settings = helion.Settings()
247+
self.assertEqual(settings.persistent_reserved_sms, 5)
248+
249+
def test_autotune_force_persistent_limits_config_spec(self) -> None:
250+
settings = helion.Settings(autotune_force_persistent=True)
251+
env = CompileEnvironment(torch.device("cpu"), settings)
252+
self.assertEqual(
253+
env.config_spec.allowed_pid_types,
254+
("persistent_blocked", "persistent_interleaved"),
255+
)
256+
257+
235258
if __name__ == "__main__":
236259
unittest.main()

0 commit comments

Comments (0)