@@ -799,6 +799,50 @@ def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launche
799799 # src[test_unroll_tuples.py:N]: return result
800800 return result
801801
802+ --- assertExpectedJournal(TestUnrollTuples.test_static_range_tuple_indexing)
803+ from __future__ import annotations
804+
805+ import torch
806+ import helion.language as hl
807+ import triton
808+ import triton.language as tl
809+ from helion.runtime import default_launcher as _default_launcher
810+
811+ @triton.jit
812+ def _helion_kernel_static_range_tuple_indexing(buf_tuple_item_0, buf_tuple_item_1, buf_tuple_item_2, buf_tuple_item_3, result, _BLOCK_SIZE_0: tl.constexpr):
813+ # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
814+ pid_0 = tl.program_id(0)
815+ offset_0 = pid_0 * _BLOCK_SIZE_0
816+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
817+ # src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
818+ acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
819+ # src[test_unroll_tuples.py:N]: acc += buf_tuple[i][tile_m]
820+ load = tl.load(buf_tuple_item_0 + indices_0 * 1, None)
821+ v_0 = acc + load
822+ load_1 = tl.load(buf_tuple_item_1 + indices_0 * 1, None)
823+ v_1 = v_0 + load_1
824+ load_2 = tl.load(buf_tuple_item_2 + indices_0 * 1, None)
825+ v_2 = v_1 + load_2
826+ load_3 = tl.load(buf_tuple_item_3 + indices_0 * 1, None)
827+ v_3 = v_2 + load_3
828+ # src[test_unroll_tuples.py:N]: result[tile_m] = acc
829+ tl.store(result + indices_0 * 1, v_3, None)
830+
831+ def kernel_static_range_tuple_indexing(buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], WORLD_SIZE: hl.constexpr, *, _launcher=_default_launcher):
832+ """Test tuple indexing with static_range - iterating over tuple elements."""
833+ # src[test_unroll_tuples.py:N]: (M,) = buf_tuple[0].shape
834+ M, = buf_tuple[0].shape
835+ # src[test_unroll_tuples.py:N]: result = torch.zeros_like(buf_tuple[0])
836+ result = torch.zeros_like(buf_tuple[0])
837+ # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
838+ _BLOCK_SIZE_0 = 32
839+ # src[test_unroll_tuples.py:N]: for tile_m in hl.tile(M):
840+ # src[test_unroll_tuples.py:N]: acc = hl.zeros([tile_m], dtype=torch.float32)
841+ # src[test_unroll_tuples.py:N-N]: ...
842+ _launcher(_helion_kernel_static_range_tuple_indexing, (triton.cdiv(32, _BLOCK_SIZE_0),), buf_tuple[0], buf_tuple[1], buf_tuple[2], buf_tuple[3], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
843+ # src[test_unroll_tuples.py:N]: return result
844+ return result
845+
802846--- assertExpectedJournal(TestUnrollTuples.test_static_range_with_start)
803847from __future__ import annotations
804848
0 commit comments