Skip to content

Commit a30ce01

Browse files
authored
Support tuple indexing by hl.static_range iterator (#1134)
1 parent 5872294 commit a30ce01

File tree

4 files changed, with 116 additions and 2 deletions (+116 -2 lines changed).

helion/_compiler/device_ir.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1166,8 +1166,9 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
11661166
assert isinstance(value, ExtendedAST)
11671167
type_info = value._type_info
11681168
if isinstance(type_info, SequenceType):
1169-
if isinstance(node.slice, ast.Constant):
1170-
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
1169+
index_value = self.visit(node.slice)
1170+
if isinstance(index_value, int):
1171+
return self.visit(value)[index_value] # pyright: ignore[reportIndexIssue]
11711172
raise exc.InvalidSequenceSubscription(node.slice)
11721173
if isinstance(type_info, StackTensorType):
11731174
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]

helion/_compiler/type_propagation.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1273,6 +1273,17 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12731273
subtype.populate_symbol_origins(GetItemOrigin(origin, i))
12741274

12751275
def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
1276+
# Tuple indexing with non-literal indices (e.g., from hl.static_range)
1277+
if self.python_type is tuple and isinstance(key, SymIntType):
1278+
if not self.element_types:
1279+
raise exc.TypeInferenceError("Cannot index empty tuple")
1280+
first_type = self.element_types[0]
1281+
if not all(type(e) is type(first_type) for e in self.element_types[1:]):
1282+
raise exc.TypeInferenceError(
1283+
"Tuple indexing with non-literal index requires all elements to have the same type"
1284+
)
1285+
return first_type
1286+
12761287
return super().propagate_getitem(key, origin)
12771288

12781289
def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:

test/test_unroll_tuples.expected

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -799,6 +799,50 @@ def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launche
799799
# src[test_unroll_tuples.py:N]: return result
800800
return result
801801

802+
--- assertExpectedJournal(TestUnrollTuples.test_static_range_tuple_indexing)
803+
from __future__ import annotations
804+
805+
import torch
806+
import helion.language as hl
807+
import triton
808+
import triton.language as tl
809+
from helion.runtime import default_launcher as _default_launcher
810+
811+
@triton.jit
812+
def _helion_kernel_static_range_tuple_indexing(buf_tuple_item_0, buf_tuple_item_1, buf_tuple_item_2, buf_tuple_item_3, result, _BLOCK_SIZE_0: tl.constexpr):
813+
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
814+
pid_0 = tl.program_id(0)
815+
offset_0 = pid_0 * _BLOCK_SIZE_0
816+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
817+
# src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
818+
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
819+
# src[test_unroll_tuples.py:N]: acc += buf_tuple[i][tile_m]
820+
load = tl.load(buf_tuple_item_0 + indices_0 * 1, None)
821+
v_0 = acc + load
822+
load_1 = tl.load(buf_tuple_item_1 + indices_0 * 1, None)
823+
v_1 = v_0 + load_1
824+
load_2 = tl.load(buf_tuple_item_2 + indices_0 * 1, None)
825+
v_2 = v_1 + load_2
826+
load_3 = tl.load(buf_tuple_item_3 + indices_0 * 1, None)
827+
v_3 = v_2 + load_3
828+
# src[test_unroll_tuples.py:N]: result[tile_m] = acc
829+
tl.store(result + indices_0 * 1, v_3, None)
830+
831+
def kernel_static_range_tuple_indexing(buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], WORLD_SIZE: hl.constexpr, *, _launcher=_default_launcher):
832+
"""Test tuple indexing with static_range - iterating over tuple elements."""
833+
# src[test_unroll_tuples.py:N]: (M,) = buf_tuple[0].shape
834+
M, = buf_tuple[0].shape
835+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(buf_tuple[0])
836+
result = torch.zeros_like(buf_tuple[0])
837+
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
838+
_BLOCK_SIZE_0 = 32
839+
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
840+
# src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
841+
# src[test_unroll_tuples.py:N-N]: ...
842+
_launcher(_helion_kernel_static_range_tuple_indexing, (triton.cdiv(32, _BLOCK_SIZE_0),), buf_tuple[0], buf_tuple[1], buf_tuple[2], buf_tuple[3], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
843+
# src[test_unroll_tuples.py:N]: return result
844+
return result
845+
802846
--- assertExpectedJournal(TestUnrollTuples.test_static_range_with_start)
803847
from __future__ import annotations
804848

test/test_unroll_tuples.py

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
import helion
8+
from helion import exc
89
from helion._testing import DEVICE
910
from helion._testing import RefEagerTestBase
1011
from helion._testing import TestCase
@@ -480,6 +481,63 @@ def test_static_range_with_start(self):
480481
expected = x * 9
481482
torch.testing.assert_close(result, expected)
482483

484+
def test_static_range_tuple_indexing(self):
485+
@helion.kernel(autotune_effort="none")
486+
def kernel_static_range_tuple_indexing(
487+
buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
488+
WORLD_SIZE: hl.constexpr,
489+
) -> torch.Tensor:
490+
"""Test tuple indexing with static_range - iterating over tuple elements."""
491+
(M,) = buf_tuple[0].shape
492+
result = torch.zeros_like(buf_tuple[0])
493+
494+
for tile_m in hl.tile(M):
495+
acc = hl.zeros([tile_m], dtype=torch.float32)
496+
497+
# Use static_range to index into tuple
498+
for i in hl.static_range(WORLD_SIZE):
499+
acc += buf_tuple[i][tile_m]
500+
501+
result[tile_m] = acc
502+
503+
return result
504+
505+
size = (32,)
506+
world_size = 4
507+
508+
tensors = tuple(
509+
torch.ones(size, device=DEVICE, dtype=torch.float32) * (i + 1)
510+
for i in range(world_size)
511+
)
512+
513+
code, result = code_and_output(
514+
kernel_static_range_tuple_indexing, (tensors, world_size)
515+
)
516+
517+
self.assertExpectedJournal(code)
518+
519+
# Test correctness - should be sum of all tensors: 1 + 2 + 3 + 4 = 10
520+
expected = sum(tensors)
521+
torch.testing.assert_close(result, expected)
522+
523+
def test_static_range_tuple_indexing_requires_uniform_types(self):
524+
@helion.kernel(autotune_effort="none")
525+
def kernel_static_range_tuple_mismatch(x: torch.Tensor) -> torch.Tensor:
526+
heterogeneous = (x, 1)
527+
528+
for _tile_n in hl.tile(x.size(0)):
529+
for idx in hl.static_range(2):
530+
_ = heterogeneous[idx]
531+
return x
532+
533+
x = torch.ones((8,), device=DEVICE)
534+
535+
with self.assertRaisesRegex(
536+
exc.TypeInferenceError,
537+
r"Tuple indexing with non-literal index requires all elements to have the same type",
538+
):
539+
code_and_output(kernel_static_range_tuple_mismatch, (x,))
540+
483541
def test_mixed_constants_and_tensors(self):
484542
"""Test mixed iteration over both tensors and constants."""
485543
size = (22,)

0 commit comments

Comments (0)