Commit 5699050

Add Settings.persistent_reserved_sms (#1129)
1 parent fe970d6 commit 5699050

File tree

5 files changed: +60 −7 lines changed

- docs/api/settings.md
- helion/_compiler/program_id.py
- helion/runtime/__init__.py
- helion/runtime/settings.py
- test/test_persistent_kernels.py

docs/api/settings.md

Lines changed: 6 additions & 0 deletions
````diff
@@ -92,6 +92,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 When enabled, tensor shapes are treated as compile-time constants for optimization. Default is ``True``.
 Set ``HELION_STATIC_SHAPES=0`` to override the default if you need a compiled kernel instance to serve many shape variants.
+
+.. autoattribute:: Settings.persistent_reserved_sms
+
+Reserve this many streaming multiprocessors when launching persistent kernels. Default is ``0`` (use all SMs).
+Configure globally with ``HELION_PERSISTENT_RESERVED_SMS`` or per-kernel via ``@helion.kernel(..., persistent_reserved_sms=N)``.
 ```
 
 ### Autotuning Settings
@@ -251,6 +256,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | ``TRITON_F32_DEFAULT`` | ``dot_precision`` | Sets default floating-point precision for Triton dot products (``"tf32"``, ``"tf32x3"``, ``"ieee"``). |
 | ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the default index dtype (accepts any ``torch.<dtype>`` name, e.g. ``int64``). |
 | ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
+| ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
 | ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
 | ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
 | ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
````
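
For orientation, a minimal usage sketch of the new setting (the kernel body, the reserved count of 2, and the name `copy_kernel` are illustrative, not from this commit):

```python
import torch

import helion
import helion.language as hl


# Keep 2 SMs free for concurrent work (e.g. communication kernels);
# the persistent kernel's grid shrinks by the reserved amount.
@helion.kernel(persistent_reserved_sms=2)
def copy_kernel(x: torch.Tensor) -> torch.Tensor:
    out = x.new_empty(x.size())
    for tile in hl.tile(x.size()):
        out[tile] = x[tile]
    return out
```

Setting `HELION_PERSISTENT_RESERVED_SMS=2` instead applies the same reservation process-wide.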

helion/_compiler/program_id.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -420,9 +420,11 @@ def __init__(self, is_blocked: bool = False) -> None:
             "step": NUM_SM_VAR,
         }
         if device_function.constexpr_arg(NUM_SM_VAR):
+            reserved_sms = CompileEnvironment.current().settings.persistent_reserved_sms
+            reserved_arg = f", reserved_sms={reserved_sms}" if reserved_sms > 0 else ""
             device_function.codegen.host_statements.append(
                 statement_from_string(
-                    f"{NUM_SM_VAR} = helion.runtime.get_num_sm({self.get_device_str()})"
+                    f"{NUM_SM_VAR} = helion.runtime.get_num_sm({self.get_device_str()}{reserved_arg})"
                 )
             )
 
```
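
The keyword argument is only emitted when a reservation is actually requested, so generated host code is unchanged at the default of ``0``. With `persistent_reserved_sms=3`, the emitted statement looks roughly like the line below (`_NUM_SM` is an assumed stand-in for whatever `NUM_SM_VAR` expands to; the test added below asserts only that the `reserved_sms=3` substring appears):

```python
# Sketch of the generated host statement when persistent_reserved_sms=3;
# `_NUM_SM` and `x` are illustrative names, not the exact generated ones.
_NUM_SM = helion.runtime.get_num_sm(x.device, reserved_sms=3)
```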

helion/runtime/__init__.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -38,27 +38,40 @@ def set_triton_allocator() -> None:
     set_allocator(_alloc_fn)
 
 
-def get_num_sm(device: torch.device) -> int:
+def get_num_sm(device: torch.device, *, reserved_sms: int = 0) -> int:
     """
     Get the number of streaming multiprocessors (SMs) for the specified device.
 
     Args:
         device: Device to query.
+        reserved_sms: Number of SMs to keep free for other work (e.g., communication
+            kernels). Defaults to 0, meaning all device SMs are available to Helion.
 
     Returns:
-        Grid size to use for a persistent kernel on the device.
+        Grid size to use for a persistent kernel on the device after accounting
+        for any reserved SMs. Always at least 1.
     """
     assert device.type in ["cuda", "xpu", "cpu"], "TODO: implement for other devices"
+    available_sms: int
     if device.type == "cpu":
         try:
             num_threads = int(torch.get_num_threads())
         except Exception:
             num_threads = 0
-        return num_threads if num_threads > 0 else int(os.cpu_count() or 1)
-    if device.type == "cuda":
-        return torch.cuda.get_device_properties(device.index).multi_processor_count
+        available_sms = num_threads if num_threads > 0 else int(os.cpu_count() or 1)
+    elif device.type == "cuda":
+        available_sms = torch.cuda.get_device_properties(
+            device.index
+        ).multi_processor_count
     # TODO(EikanWang): gpu_subslice_count is an out-of-date term; we should update it to the XeCore count.
-    return torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    elif device.type == "xpu":
+        available_sms = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        raise AssertionError("TODO: implement for other devices")
+
+    if reserved_sms <= 0:
+        return available_sms
+    return max(available_sms - reserved_sms, 1)
 
 
 def default_launcher(
```
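
The clamp to at least 1 guarantees a persistent kernel always gets a non-empty grid, even if the reservation exceeds the device's SM count. A quick sketch of the arithmetic (the SM count of 132 is illustrative, matching an H100):

```python
import torch

from helion.runtime import get_num_sm

device = torch.device("cuda", 0)

get_num_sm(device)                     # e.g. 132: all SMs, same behavior as before
get_num_sm(device, reserved_sms=8)     # 132 - 8 = 124 persistent programs
get_num_sm(device, reserved_sms=1000)  # clamped to 1, never 0
```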

helion/runtime/settings.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -278,6 +278,13 @@ class _Settings:
     static_shapes: bool = dataclasses.field(
         default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True)
     )
+    persistent_reserved_sms: int = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_int,
+            "HELION_PERSISTENT_RESERVED_SMS",
+            0,
+        )
+    )
     autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
     autotune_log: str | None = dataclasses.field(default_factory=_get_autotune_log_path)
     autotune_compile_timeout: int = dataclasses.field(
@@ -401,6 +408,10 @@ class Settings(_Settings):
         "If True, use static shapes for all tensors. This is a performance optimization. "
         "Set HELION_STATIC_SHAPES=0 to disable."
     ),
+    "persistent_reserved_sms": (
+        "Number of streaming multiprocessors to reserve when launching persistent kernels. "
+        "Set HELION_PERSISTENT_RESERVED_SMS=N (default 0) or pass persistent_reserved_sms=N to helion.kernel."
+    ),
     "autotune_log_level": (
         "Log level for autotuning using Python logging levels. Default is logging.INFO. "
         "Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."
```

test/test_persistent_kernels.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -604,6 +604,27 @@ def simple_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.assertIn("helion.runtime.get_num_sm(", code_interleaved)
         self.assertIn("for virtual_pid in tl.range", code_interleaved)
 
+    def test_persistent_reserved_sms_setting_applies(self):
+        """Ensure persistent_reserved_sms is threaded into host code for persistent kernels."""
+
+        @helion.kernel(autotune_effort="none", persistent_reserved_sms=3)
+        def reserved_kernel(x: torch.Tensor) -> torch.Tensor:
+            out = x.new_empty(x.size())
+            for tile in hl.tile(x.size(), block_size=[32, 16]):
+                out[tile] = x[tile]
+            return out
+
+        (x,) = (torch.randn([32, 32], device=DEVICE),)
+
+        code_reserved, result_reserved = code_and_output(
+            reserved_kernel,
+            (x,),
+            pid_type="persistent_blocked",
+        )
+
+        torch.testing.assert_close(result_reserved, x)
+        self.assertIn("reserved_sms=3", code_reserved)
+
     def test_multi_loop_persistent_with_shared_program_id(self):
         """Test that multi-loop persistent kernels with ForEachProgramID work correctly.
```

0 commit comments