Skip to content

Commit c16969b

Browse files
authored
Print Triton code when error for easier debugging (#874)
1 parent dc53ff9 commit c16969b

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

helion/autotuner/base_search.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,11 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
134134
decorator = self.kernel.format_kernel_decorator(
135135
baseline_config, self.settings
136136
)
137+
triton_code = self.kernel.to_triton_code(baseline_config)
137138
raise exc.InvalidConfig(
138139
"Default config failed while computing baseline.\n"
139140
f"Default config: {decorator}\n"
141+
f"\nGenerated Triton code:\n{triton_code}\n"
140142
) from e
141143
original_args_flat, _ = tree_flatten(self._original_args)
142144
new_args_flat, _ = tree_flatten(new_args)
@@ -235,9 +237,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
235237
raise exc.TritonError(
236238
f"{type(e).__qualname__}: {e}",
237239
self.kernel.format_kernel_decorator(config, self.settings),
240+
self.kernel.to_triton_code(config),
238241
) from e
239242
if action == "warn":
240-
self.log.warning(format_triton_compile_failure(config, e))
243+
self.log.warning(format_triton_compile_failure(config, e, self.kernel))
241244
else:
242245
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
243246
return inf
@@ -277,13 +280,18 @@ def extract_launcher(
277280
# Should not reach here
278281
raise RuntimeError("Expected _ExtractedLaunchArgs exception")
279282
except _ExtractedLaunchArgs as e:
280-
precompiler = make_precompiler(e.kernel, config)(*e.args, **e.kwargs)
283+
precompiler = make_precompiler(
284+
e.kernel,
285+
config,
286+
self.kernel,
287+
)(*e.args, **e.kwargs)
281288
if precompiler is already_compiled:
282289
return PrecompileFuture.skip(self, config, True)
283290
except Exception:
284291
log.warning(
285-
"Helion autotuner precompile error for %s",
292+
"Helion autotuner precompile error for %s\n\nGenerated Triton code:\n%s",
286293
self.kernel.format_kernel_decorator(config, self.settings),
294+
self.kernel.to_triton_code(config),
287295
exc_info=True,
288296
)
289297
raise

helion/autotuner/logger.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
if TYPE_CHECKING:
1616
from ..runtime.config import Config
17+
from ..runtime.kernel import BoundKernel
1718

1819

1920
class LambdaLogger:
@@ -92,12 +93,19 @@ def _maybe_call(fn: Callable[[], str] | str) -> str:
9293
return fn
9394

9495

95-
def format_triton_compile_failure(config: Config, err: BaseException) -> str:
96+
def format_triton_compile_failure(
97+
config: Config, err: BaseException, bound_kernel: BoundKernel
98+
) -> str:
99+
kernel_decorator = bound_kernel.format_kernel_decorator(
100+
config, bound_kernel.settings
101+
)
102+
triton_code = bound_kernel.to_triton_code(config)
96103
return (
97104
"Triton compile failed. This likely indicates a bug in Triton. "
98105
"Skipping failing config.\n"
99-
f"Config: {config!r}\n"
100-
f"Error: {type(err).__name__}: {err}"
106+
f"Config: {kernel_decorator}\n"
107+
f"Error: {type(err).__name__}: {err}\n\n"
108+
f"Generated Triton code:\n{triton_code}"
101109
)
102110

103111

helion/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ class TorchOpTracingError(_WrapException):
316316

317317

318318
class TritonError(BaseError):
319-
message = "Error running generated Triton program:\n{1}\n{0}"
319+
message = "Error running generated Triton program:\n{1}\n{0}\n\nGenerated Triton code:\n{2}"
320320

321321

322322
class BaseWarning(_FixedMessage):

helion/runtime/precompile_shim.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
from triton.runtime.jit import JITFunction
1414

1515
from .config import Config
16+
from .kernel import BoundKernel
1617

1718

1819
def make_precompiler(
19-
fn: JITFunction[object], config: Config
20+
fn: JITFunction[object],
21+
config: Config,
22+
bound_kernel: BoundKernel,
2023
) -> Callable[..., Callable[[], None]]:
2124
from triton.runtime.jit import find_paths_if
2225
from triton.runtime.jit import get_iterable_path
@@ -64,7 +67,10 @@ def finish_it() -> None:
6467
except Exception as e:
6568
action = classify_triton_exception(e)
6669
if action != "debug":
67-
print(format_triton_compile_failure(config, e), file=sys.stderr)
70+
print(
71+
format_triton_compile_failure(config, e, bound_kernel),
72+
file=sys.stderr,
73+
)
6874
sys.exit(1)
6975

7076
return finish_it

0 commit comments

Comments
 (0)