From c9709b9beeb663f5acb1f405594df4cc2f601826 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 13 Nov 2025 20:39:56 -0800 Subject: [PATCH 1/2] test --- test/test_unroll_tuples.expected | 44 ++++++++++++++++++++++++ test/test_unroll_tuples.py | 58 ++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/test/test_unroll_tuples.expected b/test/test_unroll_tuples.expected index 77bdef257..1303d8a57 100644 --- a/test/test_unroll_tuples.expected +++ b/test/test_unroll_tuples.expected @@ -799,6 +799,50 @@ def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launche # src[test_unroll_tuples.py:N]: return result return result +--- assertExpectedJournal(TestUnrollTuples.test_static_range_tuple_indexing) +from __future__ import annotations + +import torch +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_static_range_tuple_indexing(buf_tuple_item_0, buf_tuple_item_1, buf_tuple_item_2, buf_tuple_item_3, result, _BLOCK_SIZE_0: tl.constexpr): + # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) + # src[test_unroll_tuples.py:N]: acc += buf_tuple[i][tile_m] + load = tl.load(buf_tuple_item_0 + indices_0 * 1, None) + v_0 = acc + load + load_1 = tl.load(buf_tuple_item_1 + indices_0 * 1, None) + v_1 = v_0 + load_1 + load_2 = tl.load(buf_tuple_item_2 + indices_0 * 1, None) + v_2 = v_1 + load_2 + load_3 = tl.load(buf_tuple_item_3 + indices_0 * 1, None) + v_3 = v_2 + load_3 + # src[test_unroll_tuples.py:N]: result[tile_m] = acc + tl.store(result + indices_0 * 1, v_3, None) + +def kernel_static_range_tuple_indexing(buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], WORLD_SIZE: hl.constexpr, *, _launcher=_default_launcher): + """Test tuple indexing with static_range - iterating over tuple elements.""" + # src[test_unroll_tuples.py:N]: (M,) = buf_tuple[0].shape + M, = buf_tuple[0].shape + # src[test_unroll_tuples.py:N]: result = torch.zeros_like(buf_tuple[0]) + result = torch.zeros_like(buf_tuple[0]) + # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M): + _BLOCK_SIZE_0 = 32 + # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M): + # src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32) + # src[test_unroll_tuples.py:N-N]: ... + _launcher(_helion_kernel_static_range_tuple_indexing, (triton.cdiv(32, _BLOCK_SIZE_0),), buf_tuple[0], buf_tuple[1], buf_tuple[2], buf_tuple[3], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) + # src[test_unroll_tuples.py:N]: return result + return result + --- assertExpectedJournal(TestUnrollTuples.test_static_range_with_start) from __future__ import annotations diff --git a/test/test_unroll_tuples.py b/test/test_unroll_tuples.py index 8d462db27..c100917b6 100644 --- a/test/test_unroll_tuples.py +++ b/test/test_unroll_tuples.py @@ -5,6 +5,7 @@ import torch import helion +from helion import exc from helion._testing import DEVICE from helion._testing import RefEagerTestBase from helion._testing import TestCase @@ -480,6 +481,63 @@ def test_static_range_with_start(self): expected = x * 9 torch.testing.assert_close(result, expected) + def test_static_range_tuple_indexing(self): + @helion.kernel(autotune_effort="none") + def kernel_static_range_tuple_indexing( + buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + WORLD_SIZE: hl.constexpr, + ) -> torch.Tensor: + """Test tuple indexing with static_range - iterating over tuple elements.""" + (M,) = buf_tuple[0].shape + result = torch.zeros_like(buf_tuple[0]) + + for tile_m in hl.tile(M): + acc = hl.zeros([tile_m], dtype=torch.float32) + + # Use static_range to index into tuple + for i in hl.static_range(WORLD_SIZE): + acc += buf_tuple[i][tile_m] + + result[tile_m] = acc + + return result + + size = (32,) + world_size = 4 + + tensors = tuple( + torch.ones(size, device=DEVICE, dtype=torch.float32) * (i + 1) + for i in range(world_size) + ) + + code, result = code_and_output( + kernel_static_range_tuple_indexing, (tensors, world_size) + ) + + self.assertExpectedJournal(code) + + # Test correctness - should be sum of all tensors: 1 + 2 + 3 + 4 = 10 + expected = sum(tensors) + torch.testing.assert_close(result, expected) + + def test_static_range_tuple_indexing_requires_uniform_types(self): + @helion.kernel(autotune_effort="none") + def kernel_static_range_tuple_mismatch(x: torch.Tensor) -> torch.Tensor: + heterogeneous = (x, 1) + + for _tile_n in hl.tile(x.size(0)): + for idx in hl.static_range(2): + _ = heterogeneous[idx] + return x + + x = torch.ones((8,), device=DEVICE) + + with self.assertRaisesRegex( + exc.TypeInferenceError, + r"Tuple indexing with non-literal index requires all elements to have the same type", + ): + code_and_output(kernel_static_range_tuple_mismatch, (x,)) + def test_mixed_constants_and_tensors(self): """Test mixed iteration over both tensors and constants.""" size = (22,) From 2dca0546fc2c246e290a608cfcaabb9fa5d58c13 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 13 Nov 2025 20:40:01 -0800 Subject: [PATCH 2/2] fix --- helion/_compiler/device_ir.py | 5 +++-- helion/_compiler/type_propagation.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 4b5bdfdb9..6e5ca394f 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -1166,8 +1166,9 @@ def visit_Subscript(self, node: ast.Subscript) -> object: assert isinstance(value, ExtendedAST) type_info = value._type_info if isinstance(type_info, SequenceType): - if isinstance(node.slice, ast.Constant): - return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue] + index_value = self.visit(node.slice) + if isinstance(index_value, int): + return self.visit(value)[index_value] # pyright: ignore[reportIndexIssue] raise exc.InvalidSequenceSubscription(node.slice) if isinstance(type_info, StackTensorType): return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType] diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 11ec387d6..8ae1e423e 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -1273,6 +1273,17 @@ def populate_symbol_origins(self, origin: Origin) -> None: subtype.populate_symbol_origins(GetItemOrigin(origin, i)) def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: + # Tuple indexing with non-literal indices (e.g., from hl.static_range) + if self.python_type is tuple and isinstance(key, SymIntType): + if not self.element_types: + raise exc.TypeInferenceError("Cannot index empty tuple") + first_type = self.element_types[0] + if not all(type(e) is type(first_type) for e in self.element_types[1:]): + raise exc.TypeInferenceError( + "Tuple indexing with non-literal index requires all elements to have the same type" + ) + return first_type + return super().propagate_getitem(key, origin) def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: