Commit b5430fe

Add warning to discourage use of acc += lhs @ rhs pattern (#1111)
1 parent d44009f commit b5430fe

File tree

docs/api/exceptions.md
helion/_compiler/type_propagation.py
helion/exc.py
test/test_dot.py

4 files changed: +189 −0 lines changed

docs/api/exceptions.md

Lines changed: 4 additions & 0 deletions

@@ -350,6 +350,10 @@ Warnings can be suppressed by including them in the `ignore_warnings` setting:

     Warns when operations return tensors on wrong device.

+.. autoclass:: TiledKMatmulAccumulationWarning
+
+    Warns when ``acc += lhs @ rhs`` pattern is used inside tiled device loops.
+
 ```

 ### Warning Suppression
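
The hunk above extends the doc section on warning suppression. As a rough illustration (not part of this commit), a kernel that intentionally keeps the `acc += lhs @ rhs` pattern could silence the new warning through the `ignore_warnings` setting; passing it as a decorator keyword is an assumption based on how other settings such as `static_shapes` are passed to `@helion.kernel`.

```python
import torch
import helion
import helion.exc
import helion.language as hl

# Hedged sketch: assumes @helion.kernel forwards ignore_warnings=[...] to its
# settings, mirroring the `ignore_warnings` setting named in the docs above.
@helion.kernel(
    static_shapes=True,
    ignore_warnings=[helion.exc.TiledKMatmulAccumulationWarning],
)
def matmul_suppressed(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.shape
    k2, n = y.shape
    assert k == k2
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            # Still the discouraged pattern, but the warning is suppressed.
            acc += x[tile_m, tile_k] @ y[tile_k, tile_n]
        out[tile_m, tile_n] = acc
    return out
```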

helion/_compiler/type_propagation.py

Lines changed: 32 additions & 0 deletions

@@ -1646,6 +1646,32 @@ def generic_visit(self, node: ast.AST) -> TypeInfo:
         super().generic_visit(node)
         raise exc.UnsupportedPythonType(f"ast.{node.__class__.__name__}")

+    @staticmethod
+    def _contains_matmul(node: ast.AST | None) -> bool:
+        if node is None:
+            return False
+
+        matmul_functions = ["torch.matmul", "torch.mm", "torch.bmm", "hl.dot"]
+
+        for sub_node in ast.walk(node):
+            # Check for @ operator
+            if isinstance(sub_node, ast.BinOp) and isinstance(sub_node.op, ast.MatMult):
+                return True
+
+            # Check for function calls
+            if not isinstance(sub_node, ast.Call):
+                continue
+
+            func = sub_node.func
+
+            # Check for matmul function calls
+            if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
+                qualified_name = f"{func.value.id}.{func.attr}"
+                if qualified_name in matmul_functions:
+                    return True
+
+        return False
+
     def _bool_op(self, op: ast.boolop, left: TypeInfo, right: TypeInfo) -> TypeInfo:
         try:
             val = left.truth_value()
@@ -2094,6 +2120,12 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> TypeInfo:

     def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
         assert isinstance(node.target, ExtendedAST)
+        if (
+            self.device_loop_depth > 0
+            and isinstance(node.op, ast.Add)
+            and self._contains_matmul(node.value)
+        ):
+            warning(exc.TiledKMatmulAccumulationWarning)
         try:
             type_info = self.visit(
                 create(

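To see what the new check matches, here is a small standalone sketch (not part of the commit) that reproduces the `_contains_matmul` walk and the simplified `visit_AugAssign` condition on plain `ast` input; the hypothetical `should_warn` helper omits the `device_loop_depth > 0` gate that the real check also requires.

```python
import ast

MATMUL_FUNCTIONS = ["torch.matmul", "torch.mm", "torch.bmm", "hl.dot"]

def contains_matmul(node: ast.AST | None) -> bool:
    # Mirrors _contains_matmul above: walk the expression looking for either
    # the @ operator or a call to one of the known matmul entry points.
    if node is None:
        return False
    for sub_node in ast.walk(node):
        if isinstance(sub_node, ast.BinOp) and isinstance(sub_node.op, ast.MatMult):
            return True
        if isinstance(sub_node, ast.Call):
            func = sub_node.func
            if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
                if f"{func.value.id}.{func.attr}" in MATMUL_FUNCTIONS:
                    return True
    return False

def should_warn(stmt: ast.stmt) -> bool:
    # Simplified form of the visit_AugAssign check: warn for `acc += <matmul>`.
    return (
        isinstance(stmt, ast.AugAssign)
        and isinstance(stmt.op, ast.Add)
        and contains_matmul(stmt.value)
    )

for src in ["acc += lhs @ rhs", "acc += hl.dot(lhs, rhs)", "acc = acc + 1"]:
    stmt = ast.parse(src).body[0]
    print(src, "->", should_warn(stmt))
# acc += lhs @ rhs -> True
# acc += hl.dot(lhs, rhs) -> True
# acc = acc + 1 -> False
```
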
helion/exc.py

Lines changed: 16 additions & 0 deletions

@@ -414,6 +414,22 @@ class BlockSizeIgnoredInInterpretMode(BaseWarning):
     message = "block_size is specified to be {0}, but in interpret mode, the full dimension size is always used."


+class TiledKMatmulAccumulationWarning(BaseWarning):
+    message = (
+        "Detected one of the following usage patterns inside a Helion device loop:\n"
+        "- `acc += lhs @ rhs`\n"
+        "- `acc += torch.matmul(lhs, rhs)`\n"
+        "- `acc += torch.mm(lhs, rhs)`\n"
+        "- `acc += torch.bmm(lhs, rhs)`\n"
+        "- `acc += hl.dot(lhs, rhs)`\n"
+        "For accurate numerics, please use one of:\n"
+        "- `torch.addmm(acc, ...)`\n"
+        "- `torch.baddbmm(acc, ...)`\n"
+        "- `hl.dot(acc=...)`\n"
+        "to accumulate across tiled-K iterations of a matmul operation."
+    )
+
+
 class AutotuningDisallowedInEnvironment(BaseError):
     message = "Autotuning is disabled {0}, please provide a config to @helion.kernel via the config= argument."

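The replacements the message recommends are ordinary PyTorch calls. A minimal eager-mode sketch (not from the commit) showing that `torch.addmm` / `torch.baddbmm` fold one K-tile contribution into `acc` in a single op, which is the form the warning steers tiled-K loops toward:

```python
import torch

# One K-tile contribution folded into acc, matmul (2D) case.
m, k, n = 8, 4, 8
acc = torch.zeros(m, n)
lhs = torch.randn(m, k)
rhs = torch.randn(k, n)

discouraged = acc + lhs @ rhs           # the `acc += lhs @ rhs` pattern
preferred = torch.addmm(acc, lhs, rhs)  # accumulate inside the matmul op
torch.testing.assert_close(discouraged, preferred)

# Batched (bmm) case: torch.baddbmm plays the same role.
b = 2
acc_b = torch.zeros(b, m, n)
lhs_b = torch.randn(b, m, k)
rhs_b = torch.randn(b, k, n)
torch.testing.assert_close(
    acc_b + torch.bmm(lhs_b, rhs_b),
    torch.baddbmm(acc_b, lhs_b, rhs_b),
)
```
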
test/test_dot.py

Lines changed: 137 additions & 0 deletions

@@ -1,5 +1,7 @@
 from __future__ import annotations

+import contextlib
+import io
 import itertools
 from typing import Callable
 import unittest
@@ -288,6 +290,141 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
         expected = torch.bmm(A, B).to(result.dtype) * 2
         torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)

+    def _assert_warning_in_stderr(
+        self, kernel, args, expected_result, warning_str, *, atol=1e-2, rtol=1e-2
+    ):
+        stderr_buffer = io.StringIO()
+        with contextlib.redirect_stderr(stderr_buffer):
+            _, out = code_and_output(kernel, args)
+
+        torch.testing.assert_close(out, expected_result, atol=atol, rtol=rtol)
+
+        warning_text = stderr_buffer.getvalue()
+        self.assertIn(warning_str, warning_text)
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_at_operator_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += lhs @ rhs
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_matmul_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += torch.matmul(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_mm_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += torch.mm(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_bmm_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            b, m, k = x.shape
+            b2, k2, n = y.shape
+            assert b == b2 and k == k2
+            out = torch.empty([b, m, n], dtype=x.dtype, device=x.device)
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                acc = hl.zeros([tile_b, tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_b, tile_m, tile_k]
+                    rhs = y[tile_b, tile_k, tile_n]
+                    acc += torch.bmm(lhs, rhs)
+                out[tile_b, tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(4, 32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(4, 16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel,
+            (x, y),
+            torch.bmm(x, y),
+            "WARNING[TiledKMatmulAccumulationWarning]",
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_hl_dot_warning(self):
+        @helion.kernel(static_shapes=True)
+        def no_warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += hl.dot(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            no_warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
     # Note: numerical behavior for differing acc dtype is covered by existing dot tests; here we focus on codegen shape

     # torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here

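For contrast with the warning-triggering kernels above, a sketch of the accumulation style the warning recommends; it assumes `hl.dot` accepts the `acc=` argument referenced in the warning message, and is illustrative only, not part of the committed tests.

```python
import torch
import helion
import helion.language as hl

# Sketch of the recommended accumulation style from the warning message,
# assuming hl.dot takes an acc= keyword as that message states.
@helion.kernel(static_shapes=True)
def matmul_fused_acc(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.shape
    k2, n = y.shape
    assert k == k2
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            lhs = x[tile_m, tile_k]
            rhs = y[tile_k, tile_n]
            # Accumulate inside the dot instead of `acc += lhs @ rhs`,
            # so the tiled-K accumulation is not flagged.
            acc = hl.dot(lhs, rhs, acc=acc)
        out[tile_m, tile_n] = acc
    return out
```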