28 changes: 27 additions & 1 deletion helion/_compiler/compile_environment.py
@@ -80,7 +80,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
self.device = device
self.settings = settings
self.shape_env = ShapeEnv(
specialize_zero_one=True,
specialize_zero_one=(settings.shape_bucketing == "min2"),
duck_shape=False,
assume_static_by_default=settings.static_shapes,
)
@@ -336,6 +336,32 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
result = self.fake_mode.fake_tensor_converter.from_real_tensor(
self.fake_mode, tensor, shape_env=self.shape_env, source=source
)
# When disabling 0/1 specialization (zero_nonzero), ensure non-zero dims are symbolic
if (
not self.settings.static_shapes
and getattr(self.settings, "shape_bucketing", "min2") == "zero_nonzero"
Contributor: You shouldn't need getattr here.

Author: yeah, I was just trying to be on the safe side. Did the same in _kernel_type but it's really not necessary.

):
sizes = list(result.size())
need_replace = False
for i, s in enumerate(sizes):
# Keep zero distinct; symbolize any non-zero concrete int
if isinstance(s, int) and s != 0:
sym = self.cached_create_unbacked_symint(key=(source, "size", i))
sizes[i] = sym
need_replace = True
# Record a friendly debug name for this dimension if available
if isinstance(source, LocalSource):
self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(
f"{source.local_name}_size{i}", integer=True
)
if need_replace:
# Recreate a FakeTensor with symbolic sizes, preserving stride/dtype/device
result = torch.empty_strided(
tuple(sizes),
result.stride(),
dtype=result.dtype,
device=result.device,
)
self.input_sources[result] = source
if isinstance(source, LocalSource):
for i, s in enumerate(result.size()):
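For intuition, a standalone sketch (not part of the PR) of which dimensions the new zero_nonzero branch symbolizes: only non-zero concrete sizes are replaced, while zero-sized dims stay concrete.

# Illustrative only: mirrors the dim-selection logic above, without FakeTensor/ShapeEnv.
sizes = [0, 1, 128]
symbolized = [i for i, s in enumerate(sizes) if isinstance(s, int) and s != 0]
print(symbolized)  # [1, 2] -> dims 1 and 2 become unbacked symints; dim 0 stays 0
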
4 changes: 3 additions & 1 deletion helion/_compiler/tile_strategy.py
@@ -663,7 +663,9 @@ def select_pid_strategy(self) -> ProgramIDs:

def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST:
if isinstance(x, ast.AST):
if to_dtype:
# Casting with .to(...) is only valid for Triton tensor expressions
# Skip casting for simple names/constants representing host scalar args
if to_dtype and not isinstance(x, (ast.Name, ast.Constant)):
return expr_from_string(f"{{value}}.to({to_dtype})", value=x)
return x
if isinstance(x, int):
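A minimal sketch (outside the PR, with assumed inputs) of what the new guard distinguishes: plain names and constants, which stand for host scalar arguments, skip the `.to(...)` cast, while other expressions still get wrapped.

# Illustrative only: node kinds the isinstance guard separates.
import ast

nodes = [
    ast.parse("tl.load(ptr)", mode="eval").body,  # Call node -> wrap in .to(...)
    ast.Name(id="x_size_1", ctx=ast.Load()),      # plain name -> skip cast
    ast.Constant(value=4),                        # literal -> skip cast
]
for node in nodes:
    skip = isinstance(node, (ast.Name, ast.Constant))
    print(type(node).__name__, "skip cast" if skip else "wrap in .to(...)")
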
3 changes: 3 additions & 0 deletions helion/autotuner/base_search.py
@@ -188,7 +188,10 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
baseline_config,
prefix=f"Generated Triton code for {decorator}:",
)
<<<<<<< HEAD
Contributor: Whoops?
Author: let me clear this up.
self.kernel.maybe_log_repro(self.log.error, new_args, baseline_config)
=======
>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054))
raise exc.InvalidConfig(
"Default config failed while computing baseline.\n"
f"Default config: {decorator}\n"
15 changes: 13 additions & 2 deletions helion/runtime/kernel.py
@@ -397,6 +397,13 @@ def configs(self) -> list[Config]:

def format_kernel_decorator(self, config: Config, settings: Settings) -> str:
"""Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation."""
# Include shape_bucketing only when non-default to keep logs compact
if getattr(settings, "shape_bucketing", "min2") != "min2":
Contributor: Why getattr?

return (
f"@helion.kernel(config={config.__repr__()}, "
f"static_shapes={settings.static_shapes}, "
f"shape_bucketing='{settings.shape_bucketing}')"
)
return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})"

def to_triton_code(
@@ -830,11 +837,15 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
(*obj.size(),),
(*obj.stride(),),
)
# Non-static path: bucket sizes for specialization. Default is 0/1/>=2 (as 2).
vals = tuple([min(s, 2) for s in obj.size()])
if getattr(fn.settings, "shape_bucketing", "min2") == "zero_nonzero":
Contributor: Same

# Keep zero distinct; unify 1 with >=2 to reduce variant churn
vals = tuple(0 if v == 0 else 2 for v in vals)
return (
obj.dtype,
obj.device.type,
# 0, 1, or >=2 specialization
tuple([min(s, 2) for s in obj.size()]),
vals,
)


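For reference, a standalone sketch of the two cache-key bucketing policies used above (hypothetical helper mirroring the diff's logic, not Helion's API):

def bucket_sizes(sizes: tuple[int, ...], policy: str = "min2") -> tuple[int, ...]:
    # Default: bucket every dim into 0, 1, or >=2 (represented as 2).
    vals = tuple(min(s, 2) for s in sizes)
    if policy == "zero_nonzero":
        # Keep 0 distinct; merge 1 with >=2 so size-1 and size-2 inputs share a variant.
        vals = tuple(0 if v == 0 else 2 for v in vals)
    return vals

print(bucket_sizes((0, 1, 7)))                   # (0, 1, 2)
print(bucket_sizes((0, 1, 7), "zero_nonzero"))   # (0, 2, 2)
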
27 changes: 27 additions & 0 deletions helion/runtime/settings.py
@@ -232,6 +232,21 @@ def _get_autotune_random_seed() -> int:
return int(time.time() * 1000) % 2**32


def _get_shape_bucketing() -> Literal["min2", "zero_nonzero"]:
val = _env_get_literal(
"HELION_SHAPE_BUCKETING",
"min2",
mapping={
"min2": "min2",
"zero_nonzero": "zero_nonzero",
},
)
# Narrow to Literal explicitly
if val == "zero_nonzero":
return "zero_nonzero"
return "min2"


def _get_ref_mode() -> RefMode:
interpret = _env_get_bool("HELION_INTERPRET", False)
return RefMode.EAGER if interpret else RefMode.OFF
@@ -347,6 +362,12 @@ class _Settings:
_env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False
)
)
# Controls non-static shape specialization bucketing. When "min2" (default),
# we bucket dynamic sizes per-dimension into 0, 1, or >=2 (represented as 2).
# When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce churn.
shape_bucketing: Literal["min2", "zero_nonzero"] = dataclasses.field(
default_factory=_get_shape_bucketing
)
Comment on lines +365 to +370

Contributor: After some thought, perhaps instead of adding a new config we should make static_shapes an enum of "all", "ones", "none", since if I set static_shapes=True this setting does nothing.

We will need backcompat for True/False, but that might result in a cleaner config.

Author: Okay, so I was thinking we can do something like this:

  • static_shapes = "all" would be equivalent to setting static_shapes=True
  • static_shapes = "ones" would be the "min2" case, meaning specialize 0/1.
  • static_shapes = "none" would be this "zero_nonzero" case, basically disabling 0/1 specialization.

To make backcompat for True/False, we can map True -> "all" and False -> "none", and then HELION_STATIC_SHAPES can accept "all", "ones", "none".
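
A hypothetical backcompat shim for the enum proposal above (names assumed, not part of this PR), mapping the legacy booleans onto the proposed values:

from typing import Literal, Union

StaticShapes = Literal["all", "ones", "none"]

def normalize_static_shapes(value: Union[bool, StaticShapes]) -> StaticShapes:
    # Legacy booleans, per the proposal above: True -> "all", False -> "none".
    if value is True:
        return "all"
    if value is False:
        return "none"
    return value  # already "all" | "ones" | "none"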

ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
autotuner_fn: AutotunerFunction = default_autotuner_fn
autotune_baseline_fn: Callable[..., object] | None = None
@@ -401,6 +422,12 @@ class Settings(_Settings):
),
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
"debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.",
"shape_bucketing": (
"Dynamic-shape specialization policy when static_shapes=False. "
"'min2' buckets each dimension into 0,1,>=2 (current behavior). "
"'zero_nonzero' keeps 0 distinct and unifies 1 with >=2 to reduce variants. "
"Override with HELION_SHAPE_BUCKETING=min2|zero_nonzero."
),
"ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
"autotuner_fn": (
"Function to create an autotuner. "
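A hedged usage sketch (assuming the decorator forwards settings kwargs, as the format_kernel_decorator output above suggests); the same choice can also come from HELION_SHAPE_BUCKETING=zero_nonzero in the environment.

import torch
import helion
import helion.language as hl

# Dynamic shapes with the new bucketing policy: 0 stays distinct, 1 is unified with >=2.
@helion.kernel(static_shapes=False, shape_bucketing="zero_nonzero")
def pw_add(x: torch.Tensor, out: torch.Tensor) -> None:
    for i in hl.grid(x.size(0)):
        for j in hl.grid(x.size(1)):
            out[i, j] = x[i, j] + 1.0
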
9 changes: 9 additions & 0 deletions test/test_autotuner.py
@@ -597,7 +597,10 @@ def wrong_fn(*fn_args, **fn_kwargs):
run_mode("fork", expect_error=False)
run_mode("spawn", expect_error=True)

<<<<<<< HEAD
Contributor: Whoops?
@skipIfCpu("fails on Triton CPU backend")
=======
>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054))
def test_autotune_baseline_fn(self) -> None:
"""Test that custom baseline function is used for accuracy checking."""
config1 = helion.Config(block_sizes=[32], num_warps=4)
@@ -638,7 +641,10 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# Verify the result is correct
torch.testing.assert_close(result, args[0] + args[1])

<<<<<<< HEAD
@skipIfCpu("fails on Triton CPU backend")
=======
>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054))
def test_autotune_baseline_fn_filters_bad_config(self) -> None:
"""Test that custom baseline function correctly filters incorrect configs."""
bad_config = helion.Config(block_sizes=[1], num_warps=8)
@@ -737,7 +743,10 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
):
add(*args)

<<<<<<< HEAD
@skipIfCpu("fails on Triton CPU backend")
=======
>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054))
def test_max_generations(self):
"""Autotuner max generation respects explicit kwargs then setting override."""

62 changes: 62 additions & 0 deletions test/test_shape_bucketing.expected
@@ -0,0 +1,62 @@
This file is automatically generated by assertExpectedJournal calls in test_shape_bucketing.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestShapeBucketing.test_zero_nonzero_codegen_identical_m1_vs_m2)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

import test.test_shape_bucketing as _source_module

@triton.jit
def _helion_pw_add_fn(x, out, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1):
# src[test_shape_bucketing.py:N]: for i in grid(x.size(0)):
pid_0 = tl.program_id(0)
offset_0 = pid_0
# src[test_shape_bucketing.py:N]: for j in grid(x.size(1)):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
for offset_1 in tl.range(0, x_size_1):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
v_0 = 1.0
v_1 = load + v_0
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)

def pw_add_fn(x: torch.Tensor, out: torch.Tensor, *, _launcher=_default_launcher):
# src[test_shape_bucketing.py:N]: for i in grid(x.size(0)):
# src[test_shape_bucketing.py:N]: for j in grid(x.size(1)):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
_launcher(_helion_pw_add_fn, (x.size(0),), x, out, x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=1)

--- assertExpectedJournal(TestShapeBucketing.test_zero_nonzero_codegen_identical_m1_vs_m2)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

import test.test_shape_bucketing as _source_module

@triton.jit
def _helion_pw_add_fn(x, out, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1):
# src[test_shape_bucketing.py:N]: for i in grid(x.size(0)):
pid_0 = tl.program_id(0)
offset_0 = pid_0
# src[test_shape_bucketing.py:N]: for j in grid(x.size(1)):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
for offset_1 in tl.range(0, x_size_1):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
v_0 = 1.0
v_1 = load + v_0
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)

def pw_add_fn(x: torch.Tensor, out: torch.Tensor, *, _launcher=_default_launcher):
# src[test_shape_bucketing.py:N]: for i in grid(x.size(0)):
# src[test_shape_bucketing.py:N]: for j in grid(x.size(1)):
# src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0
_launcher(_helion_pw_add_fn, (x.size(0),), x, out, x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=1)