Skip to content

Commit 7acbc82

Browse files
authored
Fix scalar broadcast bug in inductor lowering (#1159)
1 parent 4cd42b5 commit 7acbc82

File tree

5 files changed

+178
-179
lines changed

5 files changed

+178
-179
lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,10 @@ def input_asts(self, ctx: LoweringContext, node: torch.fx.Node) -> list[ast.AST]
358358
def visit(n: torch.fx.Node) -> None:
359359
ast_val = cast("ast.AST", ctx.env[n])
360360
if isinstance(fake_val := n.meta["val"], torch.Tensor):
361-
if fake_val.ndim < ndim:
362-
# Broadcast to force ranks to match
361+
# Don't expand scalars (0-D tensors) - let Triton handle broadcasting naturally
362+
# Expanding scalars with [None, None] creates incorrect broadcast shapes
363+
if fake_val.ndim < ndim and fake_val.ndim > 0:
364+
# Broadcast to force ranks to match (but only for non-scalar tensors)
363365
expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim
364366
ast_val = expr_from_string(
365367
"{tensor}[" + ", ".join(expand) + "]", tensor=ast_val

test/test_control_flow.expected

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,18 @@ def _helion_mul_relu_block_backward_kernel(x, y, dz, dx, dy, _BLOCK_SIZE_0: tl.c
102102
# src[test_control_flow.py:N]: relu_grad = torch.where(relu_mask, 1, 0)
103103
v_3 = tl.full([], 0, tl.int64)
104104
v_4 = tl.full([], 1, tl.int64)
105-
v_5 = v_4[None, None]
106-
v_6 = v_3[None, None]
107-
v_7 = tl.where(v_2, v_5, v_6)
105+
v_5 = tl.where(v_2, v_4, v_3)
108106
# src[test_control_flow.py:N]: dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]
109-
v_8 = tl.cast(v_7, tl.float32)
110-
v_9 = dz_tile * v_8
107+
v_6 = tl.cast(v_5, tl.float32)
108+
v_7 = dz_tile * v_6
111109
subscript_1 = y_tile[:, None]
112-
v_10 = v_9 * subscript_1
113-
tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_10, None)
110+
v_8 = v_7 * subscript_1
111+
tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_8, None)
114112
# src[test_control_flow.py:N]: local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)
115-
v_11 = tl.cast(v_7, tl.float32)
116-
v_12 = dz_tile * v_11
117-
v_13 = v_12 * x_tile
118-
local_dy_grad = tl.cast(tl.sum(v_13, 1), tl.float32)
113+
v_9 = tl.cast(v_5, tl.float32)
114+
v_10 = dz_tile * v_9
115+
v_11 = v_10 * x_tile
116+
local_dy_grad = tl.cast(tl.sum(v_11, 1), tl.float32)
119117
# src[test_control_flow.py:N]: hl.atomic_add(dy, [tile_i], local_dy_grad)
120118
tl.atomic_add(dy + indices_0 * 1, local_dy_grad, mask=None, sem='relaxed')
121119

0 commit comments

Comments (0)