Skip to content

Commit 4cd42b5

Browse files
authored
Fix min hoisting bug (#1157)
1 parent a9c72c7 commit 4cd42b5

File tree

4 files changed

+68
-8
lines changed

4 files changed

+68
-8
lines changed

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from .source_location import SourceLocation
2626
from .source_location import current_location
2727
from .variable_origin import BlockSizeOrigin
28+
from .variable_origin import GridOrigin
2829
from .variable_origin import Origin
2930

3031
if TYPE_CHECKING:
@@ -453,7 +454,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None:
453454
Get the block ID associated with a given size expression.
454455
455456
This method determines if a size expression corresponds to a registered block size
456-
in the current compilation environment. It looks up the origin information of
457+
or grid index in the current compilation environment. It looks up the origin information of
457458
symbolic expressions to find their associated block IDs.
458459
459460
Args:
@@ -470,7 +471,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None:
470471
origin_info = HostFunction.current().expr_to_origin.get(size)
471472
if origin_info is not None and isinstance(
472473
origin_info.origin,
473-
BlockSizeOrigin,
474+
(BlockSizeOrigin, GridOrigin),
474475
):
475476
return origin_info.origin.block_id
476477
return None

helion/_compiler/device_function.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -414,12 +414,16 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
414414

415415
def user_sympy_expr(self, expr: sympy.Expr) -> str:
416416
"""A sympy expression that flows into user computations."""
417+
expr_to_origin = HostFunction.current().expr_to_origin
417418
replacements = {}
418419
for sym in sorted(expr.free_symbols, key=lambda s: s.name):
419420
assert isinstance(sym, sympy.Symbol)
420-
block_idx = CompileEnvironment.current().get_block_id(sym)
421-
if block_idx is not None:
422-
replacements[sym] = self.tile_strategy.user_size(block_idx)
421+
origin_info = expr_to_origin.get(sym)
422+
if origin_info is None:
423+
continue
424+
origin = origin_info.origin
425+
if isinstance(origin, BlockSizeOrigin):
426+
replacements[sym] = self.tile_strategy.user_size(origin.block_id)
423427
if replacements:
424428
# pyrefly: ignore [bad-assignment]
425429
expr = expr.xreplace(replacements)

helion/language/_tracing_ops.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from .._compiler.ast_extension import statement_from_string
1717
from .._compiler.compile_environment import CompileEnvironment
1818
from .._compiler.host_function import HostFunction
19+
from .._compiler.variable_origin import BlockSizeOrigin
1920
from ..exc import NotInsideKernel
2021
from . import _decorators
2122
from .tile_proxy import Tile
@@ -50,13 +51,16 @@ def _(state: CodegenState) -> ast.AST:
5051
return expr_from_string(str(val))
5152

5253
assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
54+
sym_expr = val._sympy_()
55+
origin_info = HostFunction.current().expr_to_origin.get(sym_expr)
5356
# pyrefly: ignore [bad-argument-type]
54-
if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None:
55-
block_size_var = state.device_function.block_size_var(block_idx)
57+
if origin_info is not None and isinstance(origin_info.origin, BlockSizeOrigin):
58+
block_size_var = state.device_function.block_size_var(
59+
origin_info.origin.block_id
60+
)
5661
if block_size_var is None:
5762
return expr_from_string("1")
5863
return expr_from_string(block_size_var)
59-
sym_expr = val._sympy_()
6064
return state.codegen.lift_symnode(
6165
expr_from_string(state.sympy_expr(sym_expr)),
6266
sym_expr,

test/test_misc.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,57 @@ def kernel_with_duplicate_refs(x: torch.Tensor) -> torch.Tensor:
5454
code, result = code_and_output(kernel_with_duplicate_refs, (x,))
5555
torch.testing.assert_close(result, expected)
5656

57+
@skipIfRefEager("block_size=1 doesn't work in ref eager mode")
58+
def test_min_hoist(self):
59+
"""Test case to reproduce issue #1155: offsets are hoisted out of loops"""
60+
61+
@helion.kernel(autotune_effort="none")
62+
def kernel(
63+
k: torch.Tensor,
64+
w: torch.Tensor,
65+
u: torch.Tensor,
66+
g: torch.Tensor,
67+
chunk_size: int,
68+
) -> torch.Tensor:
69+
batch, seqlen, nheads = g.shape
70+
dstate = u.shape[-1]
71+
chunk_size = hl.specialize(chunk_size)
72+
nchunks = (seqlen + chunk_size - 1) // chunk_size
73+
out = torch.empty(
74+
(batch, nchunks, nheads, dstate), device=g.device, dtype=g.dtype
75+
)
76+
block_v = hl.register_block_size(dstate)
77+
for tile_b, tile_h, tile_v in hl.tile(
78+
[batch, nheads, dstate], block_size=[1, 1, block_v]
79+
):
80+
for t_i in hl.tile(seqlen, block_size=chunk_size):
81+
last = min(t_i.begin + chunk_size - 1, seqlen - 1)
82+
g_scalar = g[tile_b.begin, last, tile_h.begin]
83+
out[tile_b.begin, t_i.id, tile_h.begin, tile_v] = (
84+
g_scalar + hl.zeros([tile_v], dtype=g.dtype)
85+
)
86+
return out
87+
88+
batch, seqlen, nheads, dhead, dstate = 1, 10, 1, 1, 2
89+
chunk_size = 4
90+
k = torch.zeros(
91+
batch, seqlen, nheads, dhead, device=DEVICE, dtype=torch.float32
92+
)
93+
w = torch.zeros_like(k)
94+
u = torch.zeros(
95+
batch, seqlen, nheads, dstate, device=DEVICE, dtype=torch.float32
96+
)
97+
g = torch.arange(seqlen, device=DEVICE, dtype=torch.float32).view(
98+
batch, seqlen, nheads
99+
)
100+
101+
expected = torch.tensor(
102+
[[[[3, 3]], [[7, 7]], [[9, 9]]]], device=DEVICE, dtype=torch.float32
103+
)
104+
105+
result = kernel(k, w, u, g, chunk_size)
106+
torch.testing.assert_close(result, expected)
107+
57108
def test_torch_alloc(self):
58109
@helion.kernel(config={"block_sizes": [64, 64]})
59110
def fn(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)