5 changes: 3 additions & 2 deletions helion/_compiler/device_ir.py
@@ -1166,8 +1166,9 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
assert isinstance(value, ExtendedAST)
type_info = value._type_info
if isinstance(type_info, SequenceType):
if isinstance(node.slice, ast.Constant):
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
index_value = self.visit(node.slice)
if isinstance(index_value, int):
return self.visit(value)[index_value] # pyright: ignore[reportIndexIssue]
raise exc.InvalidSequenceSubscription(node.slice)
if isinstance(type_info, StackTensorType):
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
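For orientation, a minimal sketch of the kernel pattern this change accepts: the loop variable of hl.static_range is not an ast.Constant in the source, but visiting the slice yields a concrete int once the loop is unrolled, so the updated visit_Subscript can resolve the subscript. The names below (sum_tuple, bufs, n) are illustrative only and not part of this PR.

import torch
import helion
import helion.language as hl

@helion.kernel(autotune_effort="none")
def sum_tuple(bufs: tuple[torch.Tensor, torch.Tensor], n: hl.constexpr) -> torch.Tensor:
    # Sum the corresponding elements of every tensor in the tuple.
    (m,) = bufs[0].shape
    out = torch.zeros_like(bufs[0])
    for tile in hl.tile(m):
        acc = hl.zeros([tile], dtype=torch.float32)
        # `i` is not a literal in the AST; unrolling turns it into a plain int,
        # which the new isinstance(index_value, int) check accepts.
        for i in hl.static_range(n):
            acc += bufs[i][tile]
        out[tile] = acc
    return out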
11 changes: 11 additions & 0 deletions helion/_compiler/type_propagation.py
@@ -1273,6 +1273,17 @@ def populate_symbol_origins(self, origin: Origin) -> None:
subtype.populate_symbol_origins(GetItemOrigin(origin, i))

def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
# Tuple indexing with non-literal indices (e.g., from hl.static_range)
if self.python_type is tuple and isinstance(key, SymIntType):
if not self.element_types:
raise exc.TypeInferenceError("Cannot index empty tuple")
first_type = self.element_types[0]
if not all(type(e) is type(first_type) for e in self.element_types[1:]):
raise exc.TypeInferenceError(
"Tuple indexing with non-literal index requires all elements to have the same type"
)
return first_type

return super().propagate_getitem(key, origin)

def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
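A standalone sketch of the rule added above, using assumed names rather than the real TypeInfo API: with a symbolic (non-literal) index, every tuple element must carry the same TypeInfo class, and the subscript yields the first element's type. The class-level comparison is also why the (tensor, int) tuple in the new negative test below is rejected.

def propagate_symbolic_tuple_getitem(element_types):
    # element_types: per-element type descriptions of the tuple (assumed input shape).
    if not element_types:
        raise TypeError("Cannot index empty tuple")
    first_type = element_types[0]
    if not all(type(e) is type(first_type) for e in element_types[1:]):
        raise TypeError(
            "Tuple indexing with non-literal index requires all elements "
            "to have the same type"
        )
    # Any element's type would do after the uniformity check; return the first.
    return first_type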
44 changes: 44 additions & 0 deletions test/test_unroll_tuples.expected
@@ -799,6 +799,50 @@ def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launcher
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_static_range_tuple_indexing)
from __future__ import annotations

import torch
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_static_range_tuple_indexing(buf_tuple_item_0, buf_tuple_item_1, buf_tuple_item_2, buf_tuple_item_3, result, _BLOCK_SIZE_0: tl.constexpr):
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
# src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
# src[test_unroll_tuples.py:N]: acc += buf_tuple[i][tile_m]
load = tl.load(buf_tuple_item_0 + indices_0 * 1, None)
v_0 = acc + load
load_1 = tl.load(buf_tuple_item_1 + indices_0 * 1, None)
v_1 = v_0 + load_1
load_2 = tl.load(buf_tuple_item_2 + indices_0 * 1, None)
v_2 = v_1 + load_2
load_3 = tl.load(buf_tuple_item_3 + indices_0 * 1, None)
v_3 = v_2 + load_3
# src[test_unroll_tuples.py:N]: result[tile_m] = acc
tl.store(result + indices_0 * 1, v_3, None)

def kernel_static_range_tuple_indexing(buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], WORLD_SIZE: hl.constexpr, *, _launcher=_default_launcher):
"""Test tuple indexing with static_range - iterating over tuple elements."""
# src[test_unroll_tuples.py:N]: (M,) = buf_tuple[0].shape
M, = buf_tuple[0].shape
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(buf_tuple[0])
result = torch.zeros_like(buf_tuple[0])
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
_BLOCK_SIZE_0 = 32
# src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
# src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
# src[test_unroll_tuples.py:N-N]: ...
_launcher(_helion_kernel_static_range_tuple_indexing, (triton.cdiv(32, _BLOCK_SIZE_0),), buf_tuple[0], buf_tuple[1], buf_tuple[2], buf_tuple[3], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_static_range_with_start)
from __future__ import annotations

58 changes: 58 additions & 0 deletions test/test_unroll_tuples.py
@@ -5,6 +5,7 @@
import torch

import helion
from helion import exc
from helion._testing import DEVICE
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
@@ -480,6 +481,63 @@ def test_static_range_with_start(self):
expected = x * 9
torch.testing.assert_close(result, expected)

def test_static_range_tuple_indexing(self):
@helion.kernel(autotune_effort="none")
def kernel_static_range_tuple_indexing(
buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
WORLD_SIZE: hl.constexpr,
) -> torch.Tensor:
"""Test tuple indexing with static_range - iterating over tuple elements."""
(M,) = buf_tuple[0].shape
result = torch.zeros_like(buf_tuple[0])

for tile_m in hl.tile(M):
acc = hl.zeros([tile_m], dtype=torch.float32)

# Use static_range to index into tuple
for i in hl.static_range(WORLD_SIZE):
acc += buf_tuple[i][tile_m]

result[tile_m] = acc

return result

size = (32,)
world_size = 4

tensors = tuple(
torch.ones(size, device=DEVICE, dtype=torch.float32) * (i + 1)
for i in range(world_size)
)

code, result = code_and_output(
kernel_static_range_tuple_indexing, (tensors, world_size)
)

self.assertExpectedJournal(code)

# Test correctness - should be sum of all tensors: 1 + 2 + 3 + 4 = 10
expected = sum(tensors)
torch.testing.assert_close(result, expected)

def test_static_range_tuple_indexing_requires_uniform_types(self):
@helion.kernel(autotune_effort="none")
def kernel_static_range_tuple_mismatch(x: torch.Tensor) -> torch.Tensor:
heterogeneous = (x, 1)

for _tile_n in hl.tile(x.size(0)):
for idx in hl.static_range(2):
_ = heterogeneous[idx]
return x

x = torch.ones((8,), device=DEVICE)

with self.assertRaisesRegex(
exc.TypeInferenceError,
r"Tuple indexing with non-literal index requires all elements to have the same type",
):
code_and_output(kernel_static_range_tuple_mismatch, (x,))

def test_mixed_constants_and_tensors(self):
"""Test mixed iteration over both tensors and constants."""
size = (22,)
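For completeness, a hedged usage sketch mirroring the new test's arguments; it assumes the usual direct-call path for a @helion.kernel function and a CUDA device, neither of which this diff changes.

import torch

# Four 32-element float32 tensors holding 1.0, 2.0, 3.0, and 4.0 respectively.
tensors = tuple(
    torch.ones((32,), device="cuda", dtype=torch.float32) * (i + 1) for i in range(4)
)
# Direct call: WORLD_SIZE=4 drives the static_range unrolling over the tuple.
out = kernel_static_range_tuple_indexing(tensors, 4)
# Each output element should be 1 + 2 + 3 + 4 = 10.
torch.testing.assert_close(out, torch.full((32,), 10.0, device="cuda"))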