From a51afd4afa41aa1ad06503804f5586bdfd020498 Mon Sep 17 00:00:00 2001 From: Shikhar Mishra Date: Fri, 31 Oct 2025 00:23:43 +0530 Subject: [PATCH 1/4] Major: Used a zero-nonzero specialization in which 0 will be dealt with in a diff way while 1 just goes >= 2 similar to already present '_tensor_key' logic & this can be turned on with an environment variable 'HELION_SHAPE_BUCKETING' that default to 'min2', and disables 0/1 when using 'zero_nonzero'. --- helion/runtime/kernel.py | 15 ++++++++-- helion/runtime/settings.py | 20 +++++++++++++ test/test_shape_bucketing.py | 54 ++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 test/test_shape_bucketing.py diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 44ea26e35..4abab1700 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -391,6 +391,13 @@ def configs(self) -> list[Config]: def format_kernel_decorator(self, config: Config, settings: Settings) -> str: """Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation.""" + # Include shape_bucketing only when non-default to keep logs compact + if getattr(settings, "shape_bucketing", "min2") != "min2": + return ( + f"@helion.kernel(config={config.__repr__()}, " + f"static_shapes={settings.static_shapes}, " + f"shape_bucketing='{settings.shape_bucketing}')" + ) return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})" def to_triton_code( @@ -733,11 +740,15 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable: (*obj.size(),), (*obj.stride(),), ) + # Non-static path: bucket sizes for specialization. Default is 0/1/>=2 (as 2). + vals = tuple([min(s, 2) for s in obj.size()]) + if getattr(fn.settings, "shape_bucketing", "min2") == "zero_nonzero": + # Keep zero distinct; unify 1 with >=2 to reduce variant churn + vals = tuple(0 if v == 0 else 2 for v in vals) return ( obj.dtype, obj.device.type, - # 0, 1, or >=2 specialization - tuple([min(s, 2) for s in obj.size()]), + vals, ) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 48aa1e97c..5ce3861b1 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -343,6 +343,20 @@ class _Settings: _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False ) ) + # Controls non-static shape specialization bucketing. When "min2" (default), + # we bucket dynamic sizes per-dimension into 0, 1, or >=2 (represented as 2). + # When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce variant churn. + shape_bucketing: Literal["min2", "zero_nonzero"] = dataclasses.field( + default_factory=functools.partial( + _env_get_literal, + "HELION_SHAPE_BUCKETING", + cast('Literal["min2", "zero_nonzero"]', "min2"), + mapping={ + "min2": "min2", + "zero_nonzero": "zero_nonzero", + }, + ) + ) # pyright: ignore[reportArgumentType] ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode) autotuner_fn: AutotunerFunction = default_autotuner_fn @@ -395,6 +409,12 @@ class Settings(_Settings): ), "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.", + "shape_bucketing": ( + "Dynamic-shape specialization policy when static_shapes=False. " + "'min2' buckets each dimension into 0,1,>=2 (current behavior). " + "'zero_nonzero' keeps 0 distinct and unifies 1 with >=2 to reduce variants. " + "Override with HELION_SHAPE_BUCKETING=min2|zero_nonzero." + ), "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.", "autotuner_fn": ( "Function to create an autotuner. " diff --git a/test/test_shape_bucketing.py b/test/test_shape_bucketing.py new file mode 100644 index 000000000..e354fefb3 --- /dev/null +++ b/test/test_shape_bucketing.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import unittest + +import torch + +from helion.runtime.kernel import kernel +from helion.runtime.settings import Settings + + +def _dummy(x: torch.Tensor) -> torch.Tensor: + return x + + +class TestShapeBucketing(unittest.TestCase): + def test_min2_bucketing_default(self) -> None: + k = kernel(_dummy, settings=Settings(static_shapes=False)) + + t0 = torch.empty(0, 3) + t1 = torch.empty(1, 3) + t2 = torch.empty(2, 3) + t7 = torch.empty(7, 3) + + key_0 = k.specialization_key([t0]) + key_1 = k.specialization_key([t1]) + key_2 = k.specialization_key([t2]) + key_7 = k.specialization_key([t7]) + + # min2: 0,1,>=2 (as 2) + self.assertNotEqual(key_0, key_2) + self.assertNotEqual(key_1, key_2) + self.assertEqual(key_2, key_7) + + def test_zero_nonzero_bucketing(self) -> None: + k = kernel( + _dummy, + settings=Settings(static_shapes=False, shape_bucketing="zero_nonzero"), + ) + + t0 = torch.empty(0, 3) + t1 = torch.empty(1, 3) + t2 = torch.empty(2, 3) + + key_0 = k.specialization_key([t0]) + key_1 = k.specialization_key([t1]) + key_2 = k.specialization_key([t2]) + + # zero_nonzero: keep 0 distinct; unify 1 with >=2 + self.assertNotEqual(key_0, key_2) + self.assertEqual(key_1, key_2) + + +if __name__ == "__main__": + unittest.main() From eb13f0c11529c20a3b8b6f9b4e7b6a278ecdae51 Mon Sep 17 00:00:00 2001 From: Shikhar Mishra Date: Sat, 1 Nov 2025 13:54:25 +0530 Subject: [PATCH 2/4] Fix the type 'str' is not assignable lint error by using a helper function. --- helion/runtime/settings.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 3fc5214f0..51ab3956b 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -232,6 +232,21 @@ def _get_autotune_random_seed() -> int: return int(time.time() * 1000) % 2**32 +def _get_shape_bucketing() -> Literal["min2", "zero_nonzero"]: + val = _env_get_literal( + "HELION_SHAPE_BUCKETING", + "min2", + mapping={ + "min2": "min2", + "zero_nonzero": "zero_nonzero", + }, + ) + # Narrow to Literal explicitly + if val == "zero_nonzero": + return "zero_nonzero" + return "min2" + + def _get_ref_mode() -> RefMode: interpret = _env_get_bool("HELION_INTERPRET", False) return RefMode.EAGER if interpret else RefMode.OFF @@ -349,18 +364,10 @@ class _Settings: ) # Controls non-static shape specialization bucketing. When "min2" (default), # we bucket dynamic sizes per-dimension into 0, 1, or >=2 (represented as 2). - # When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce variant churn. + # When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce churn. shape_bucketing: Literal["min2", "zero_nonzero"] = dataclasses.field( - default_factory=functools.partial( - _env_get_literal, - "HELION_SHAPE_BUCKETING", - cast('Literal["min2", "zero_nonzero"]', "min2"), - mapping={ - "min2": "min2", - "zero_nonzero": "zero_nonzero", - }, - ) - ) # pyright: ignore[reportArgumentType] + default_factory=_get_shape_bucketing + ) ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode) autotuner_fn: AutotunerFunction = default_autotuner_fn autotune_baseline_fn: Callable[..., object] | None = None From ad805e8e29f584c32868ca2269710dd8159288de Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 30 Oct 2025 13:50:20 -0700 Subject: [PATCH 3/4] Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054) Rebased to main --- helion/autotuner/base_search.py | 3 +++ test/test_autotuner.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index a9e841f68..fb94bcdc7 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -188,7 +188,10 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]: baseline_config, prefix=f"Generated Triton code for {decorator}:", ) +<<<<<<< HEAD self.kernel.maybe_log_repro(self.log.error, new_args, baseline_config) +======= +>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054)) raise exc.InvalidConfig( "Default config failed while computing baseline.\n" f"Default config: {decorator}\n" diff --git a/test/test_autotuner.py b/test/test_autotuner.py index ce2d4abdb..7cb1a99dc 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -597,7 +597,10 @@ def wrong_fn(*fn_args, **fn_kwargs): run_mode("fork", expect_error=False) run_mode("spawn", expect_error=True) +<<<<<<< HEAD @skipIfCpu("fails on Triton CPU backend") +======= +>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054)) def test_autotune_baseline_fn(self) -> None: """Test that custom baseline function is used for accuracy checking.""" config1 = helion.Config(block_sizes=[32], num_warps=4) @@ -638,7 +641,10 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # Verify the result is correct torch.testing.assert_close(result, args[0] + args[1]) +<<<<<<< HEAD @skipIfCpu("fails on Triton CPU backend") +======= +>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054)) def test_autotune_baseline_fn_filters_bad_config(self) -> None: """Test that custom baseline function correctly filters incorrect configs.""" bad_config = helion.Config(block_sizes=[1], num_warps=8) @@ -737,7 +743,10 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ): add(*args) +<<<<<<< HEAD @skipIfCpu("fails on Triton CPU backend") +======= +>>>>>>> 69f3405 (Add `settings.autotune_baseline_fn` to allow passing in custom baseline function to autotuner (#1054)) def test_max_generations(self): """Autotuner max generation respects explicit kwargs then setting override.""" From 20234fd768a7a652f757ee5e42a8f857dcbd8f8b Mon Sep 17 00:00:00 2001 From: Shikhar Mishra Date: Wed, 5 Nov 2025 23:10:48 +0530 Subject: [PATCH 4/4] Fix the shape-bucketing error where FakeTensors coming from ShapeEnv would still get specialized. --- helion/_compiler/compile_environment.py | 28 ++- helion/_compiler/tile_strategy.py | 4 +- test/test_shape_bucketing.expected | 62 ++++++ test/test_shape_bucketing.py | 256 +++++++++++++++++++++++- 4 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 test/test_shape_bucketing.expected diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index d04e44bb1..044541e64 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -80,7 +80,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None: self.device = device self.settings = settings self.shape_env = ShapeEnv( - specialize_zero_one=True, + specialize_zero_one=(settings.shape_bucketing == "min2"), duck_shape=False, assume_static_by_default=settings.static_shapes, ) @@ -336,6 +336,32 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor: result = self.fake_mode.fake_tensor_converter.from_real_tensor( self.fake_mode, tensor, shape_env=self.shape_env, source=source ) + # When disabling 0/1 specialization (zero_nonzero), ensure non-zero dims are symbolic + if ( + not self.settings.static_shapes + and getattr(self.settings, "shape_bucketing", "min2") == "zero_nonzero" + ): + sizes = list(result.size()) + need_replace = False + for i, s in enumerate(sizes): + # Keep zero distinct; symbolize any non-zero concrete int + if isinstance(s, int) and s != 0: + sym = self.cached_create_unbacked_symint(key=(source, "size", i)) + sizes[i] = sym + need_replace = True + # Record a friendly debug name for this dimension if available + if isinstance(source, LocalSource): + self.debug_shape_renames[sym._sympy_()] = sympy.Symbol( + f"{source.local_name}_size{i}", integer=True + ) + if need_replace: + # Recreate a FakeTensor with symbolic sizes, preserving stride/dtype/device + result = torch.empty_strided( + tuple(sizes), + result.stride(), + dtype=result.dtype, + device=result.device, + ) self.input_sources[result] = source if isinstance(source, LocalSource): for i, s in enumerate(result.size()): diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 67a3d6169..65016e0fd 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -663,7 +663,9 @@ def select_pid_strategy(self) -> ProgramIDs: def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST: if isinstance(x, ast.AST): - if to_dtype: + # Casting with .to(...) is only valid for Triton tensor expressions + # Skip casting for simple names/constants representing host scalar args + if to_dtype and not isinstance(x, (ast.Name, ast.Constant)): return expr_from_string(f"{{value}}.to({to_dtype})", value=x) return x if isinstance(x, int): diff --git a/test/test_shape_bucketing.expected b/test/test_shape_bucketing.expected new file mode 100644 index 000000000..04950bdb4 --- /dev/null +++ b/test/test_shape_bucketing.expected @@ -0,0 +1,62 @@ +This file is automatically generated by assertExpectedJournal calls in test_shape_bucketing.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. + +--- assertExpectedJournal(TestShapeBucketing.test_zero_nonzero_codegen_identical_m1_vs_m2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_shape_bucketing as _source_module + +@triton.jit +def _helion_pw_add_fn(x, out, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1): + # src[test_shape_bucketing.py:N]: for i in grid(x.size(0)): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + # src[test_shape_bucketing.py:N]: for j in grid(x.size(1)): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + for offset_1 in tl.range(0, x_size_1): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None) + +def pw_add_fn(x: torch.Tensor, out: torch.Tensor, *, _launcher=_default_launcher): + # src[test_shape_bucketing.py:N]: for i in grid(x.size(0)): + # src[test_shape_bucketing.py:N]: for j in grid(x.size(1)): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + _launcher(_helion_pw_add_fn, (x.size(0),), x, out, x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=1) + +--- assertExpectedJournal(TestShapeBucketing.test_zero_nonzero_codegen_identical_m1_vs_m2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_shape_bucketing as _source_module + +@triton.jit +def _helion_pw_add_fn(x, out, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1): + # src[test_shape_bucketing.py:N]: for i in grid(x.size(0)): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + # src[test_shape_bucketing.py:N]: for j in grid(x.size(1)): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + for offset_1 in tl.range(0, x_size_1): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None) + +def pw_add_fn(x: torch.Tensor, out: torch.Tensor, *, _launcher=_default_launcher): + # src[test_shape_bucketing.py:N]: for i in grid(x.size(0)): + # src[test_shape_bucketing.py:N]: for j in grid(x.size(1)): + # src[test_shape_bucketing.py:N]: out[i, j] = x[i, j] + 1.0 + _launcher(_helion_pw_add_fn, (x.size(0),), x, out, x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=1) diff --git a/test/test_shape_bucketing.py b/test/test_shape_bucketing.py index e354fefb3..a0c3b1a0d 100644 --- a/test/test_shape_bucketing.py +++ b/test/test_shape_bucketing.py @@ -4,6 +4,9 @@ import torch +from helion._testing import TestCase +from helion._testing import skipIfNotCUDA +from helion.language import grid from helion.runtime.kernel import kernel from helion.runtime.settings import Settings @@ -12,7 +15,7 @@ def _dummy(x: torch.Tensor) -> torch.Tensor: return x -class TestShapeBucketing(unittest.TestCase): +class TestShapeBucketing(TestCase): def test_min2_bucketing_default(self) -> None: k = kernel(_dummy, settings=Settings(static_shapes=False)) @@ -49,6 +52,257 @@ def test_zero_nonzero_bucketing(self) -> None: self.assertNotEqual(key_0, key_2) self.assertEqual(key_1, key_2) + @skipIfNotCUDA() + def test_zero_nonzero_runtime_correctness(self) -> None: + # A simple pointwise kernel to exercise runtime reuse across 1 vs >=2 shapes + @kernel( + settings=Settings( + static_shapes=False, + shape_bucketing="zero_nonzero", + autotune_effort="none", + ) + ) + def pw_add(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + + # Compile with M=2 first (general), then reuse for M=1 (singleton) + x2 = torch.randn(2, K, device=device, dtype=torch.float32) + x1 = torch.randn(1, K, device=device, dtype=torch.float32) + + y2 = torch.empty_like(x2) + y1 = torch.empty_like(x1) + pw_add(x2, y2) # compile + pw_add(x1, y1) # reuse + + torch.testing.assert_close(y2, x2 + 1.0, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(y1, x1 + 1.0, rtol=1e-4, atol=1e-4) + + @skipIfNotCUDA() + def test_codegen_differs_for_singleton(self) -> None: + # Define a simple kernel function (not decorated) so we can construct two Kernel instances + def pw_add_fn(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + + # Use min2 to force distinct specialization keys per shape and avoid reuse between 1 and 2 + settings = Settings( + static_shapes=False, autotune_effort="none", shape_bucketing="min2" + ) + + k1 = kernel(pw_add_fn, settings=settings) + k2 = kernel(pw_add_fn, settings=settings) + + x1 = torch.randn(1, K, device=device, dtype=torch.float32) + y1 = torch.empty_like(x1) + b1 = k1.bind((x1, y1)) + code1 = b1.to_triton_code() + + x2 = torch.randn(2, K, device=device, dtype=torch.float32) + y2 = torch.empty_like(x2) + b2 = k2.bind((x2, y2)) + code2 = b2.to_triton_code() + + if code1 == code2: + self.skipTest( + "Generated Triton is identical for M=1 and M=2; no singleton specialization detected" + ) + else: + # Expect differing code paths when singleton specialization is present + self.assertNotEqual(code1, code2) + + @skipIfNotCUDA() + def test_zero_nonzero_general_only_single_compile(self) -> None: + # Compile first with M=1, then call with M=2 under zero_nonzero; ensure a single compiled callable is reused + @kernel( + settings=Settings( + static_shapes=False, + shape_bucketing="zero_nonzero", + autotune_effort="none", + ) + ) + def pw_add(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + + x1 = torch.randn(1, K, device=device, dtype=torch.float32) + y1 = torch.empty_like(x1) + x2 = torch.randn(2, K, device=device, dtype=torch.float32) + y2 = torch.empty_like(x2) + + # Bind on M=1 to capture the bound kernel instance + b = pw_add.bind((x1, y1)) + + # First call (M=1) → compile once + pw_add(x1, y1) + torch.testing.assert_close(y1, x1 + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + # Second call (M=2) → should not compile again + pw_add(x2, y2) + torch.testing.assert_close(y2, x2 + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + # Subsequent calls should reuse without increasing cache entries + num_entries = len(b._compile_cache) + pw_add(x2, y2) + pw_add(x1, y1) + self.assertEqual(len(b._compile_cache), num_entries) + + @skipIfNotCUDA() + def test_zero_nonzero_runtime_correctness_varying_singleton_dim_row_to_col( + self, + ) -> None: + # Compile at (1, K) then run at (K, 1) under zero_nonzero; must be correct and reuse single compiled callable + @kernel( + settings=Settings( + static_shapes=False, + shape_bucketing="zero_nonzero", + autotune_effort="none", + ) + ) + def pw_add(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + + x_row = torch.randn(1, K, device=device, dtype=torch.float32) + y_row = torch.empty_like(x_row) + x_col = torch.randn(K, 1, device=device, dtype=torch.float32) + y_col = torch.empty_like(x_col) + + # Bind on (1, K) to capture the bound kernel instance + b = pw_add.bind((x_row, y_row)) + + # First call compiles once + pw_add(x_row, y_row) + torch.testing.assert_close(y_row, x_row + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + # Flip which dim is 1; still correct; cache unchanged + pw_add(x_col, y_col) + torch.testing.assert_close(y_col, x_col + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + @skipIfNotCUDA() + def test_zero_nonzero_runtime_correctness_varying_singleton_dim_col_to_row( + self, + ) -> None: + # Compile at (K, 1) then run at (1, K) under zero_nonzero; must be correct and reuse single compiled callable + @kernel( + settings=Settings( + static_shapes=False, + shape_bucketing="zero_nonzero", + autotune_effort="none", + ) + ) + def pw_add(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + + x_col = torch.randn(K, 1, device=device, dtype=torch.float32) + y_col = torch.empty_like(x_col) + x_row = torch.randn(1, K, device=device, dtype=torch.float32) + y_row = torch.empty_like(x_row) + + b = pw_add.bind((x_col, y_col)) + + # First call compiles once + pw_add(x_col, y_col) + torch.testing.assert_close(y_col, x_col + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + # Flip which dim is 1; still correct; cache unchanged + pw_add(x_row, y_row) + torch.testing.assert_close(y_row, x_row + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + @skipIfNotCUDA() + def test_zero_nonzero_codegen_identical_m1_vs_m2(self) -> None: + # Under zero_nonzero, M=1 vs M=2 should produce identical codegen + def pw_add_fn(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + out[i, j] = x[i, j] + 1.0 + + device = torch.device("cuda", 0) + K = 16 + settings = Settings( + static_shapes=False, autotune_effort="none", shape_bucketing="zero_nonzero" + ) + + k1 = kernel(pw_add_fn, settings=settings) + k2 = kernel(pw_add_fn, settings=settings) + + x1 = torch.randn(1, K, device=device, dtype=torch.float32) + y1 = torch.empty_like(x1) + b1 = k1.bind((x1, y1)) + code1 = b1.to_triton_code() + self.assertExpectedJournal(code1) + + x2 = torch.randn(2, K, device=device, dtype=torch.float32) + y2 = torch.empty_like(x2) + b2 = k2.bind((x2, y2)) + code2 = b2.to_triton_code() + self.assertExpectedJournal(code2) + + self.assertEqual(code1, code2) + + @skipIfNotCUDA() + def test_zero_nonzero_runtime_correctness_varying_singleton_dim_3d(self) -> None: + # Compile at (1, K, K) then run across different 3D 1-ness patterns; must be correct and reuse a single compiled callable + @kernel( + settings=Settings( + static_shapes=False, + shape_bucketing="zero_nonzero", + autotune_effort="none", + ) + ) + def pw_add3d(x: torch.Tensor, out: torch.Tensor) -> None: + for i in grid(x.size(0)): + for j in grid(x.size(1)): + for k in grid(x.size(2)): + out[i, j, k] = x[i, j, k] + 1.0 + + device = torch.device("cuda", 0) + K = 8 + + x100 = torch.randn(1, K, K, device=device, dtype=torch.float32) + y100 = torch.empty_like(x100) + + # Bind and compile once with (1, K, K) + b = pw_add3d.bind((x100, y100)) + pw_add3d(x100, y100) + torch.testing.assert_close(y100, x100 + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + + # Now flip which dimension is 1 across various patterns, correctness should hold and cache size remain 1 + for shape in [(K, 1, K), (K, K, 1), (1, 1, K), (1, K, 1), (K, 1, 1)]: + x = torch.randn(*shape, device=device, dtype=torch.float32) + y = torch.empty_like(x) + pw_add3d(x, y) + torch.testing.assert_close(y, x + 1.0, rtol=1e-4, atol=1e-4) + self.assertEqual(len(b._compile_cache), 1) + if __name__ == "__main__": unittest.main()