Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
9c3e56a
Adding UCBPatternSearch autotuner and fragment_encoder
Nov 11, 2025
8035c8e
add tests and import
Nov 11, 2025
e24b529
moved config encoding to config_fragment
Nov 13, 2025
e5be414
Merge remote-tracking branch 'upstream/main' into ucb_pattern_search
Nov 14, 2025
140b3b2
adapt ucb_pattern_search to new encoder
Nov 14, 2025
ca99252
merged new config fragment
Nov 14, 2025
899d30d
imports
Nov 14, 2025
c6fc4cc
Merge branch 'main' into ucb_pattern_search
ethche Nov 14, 2025
13ea3b4
remove encode_scalar
Nov 14, 2025
9cf7dbe
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 14, 2025
3b739f5
fix imports
Nov 15, 2025
62aaf77
early stopping helper for pattern search
Nov 15, 2025
97963f5
fix tests
Nov 16, 2025
a0ef224
fix dim
Nov 16, 2025
018a626
ucb fix lints and better hyperparams
Nov 16, 2025
2a65701
revert linter changes
Nov 16, 2025
79c0fa3
name change
Nov 16, 2025
c2a578f
revert unrelated changes in config_generation
Nov 16, 2025
fc2929d
revert unrelated changes in config_generation
Nov 16, 2025
44c925d
save gp state
Nov 16, 2025
70a46fe
better ucb docstring
Nov 16, 2025
f837953
combined dependencies
Nov 16, 2025
2c2eec9
fix pyproject
Nov 16, 2025
74b0754
reverting unrelated changes to comments
Nov 16, 2025
d9cce1e
no need for encode for integer fragment, inherit from base integer
Nov 16, 2025
3437347
Merge branch 'main' into ucb_pattern_search
ethche Nov 16, 2025
4a791a9
optimize batch UCB function, simplify batch selection
Nov 17, 2025
913b330
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 17, 2025
047509d
batch optimization by default
Nov 17, 2025
eb430ce
LFBO instead of ucb_pattern_search
Nov 18, 2025
33803df
LFBO tests
Nov 18, 2025
b6c24cc
LFBO better docstring
Nov 18, 2025
e3106a8
LFBO remove patience feature
Nov 18, 2025
1c810f0
LFBO imports
Nov 18, 2025
c25362d
Fix comments
Nov 18, 2025
b5a65be
Fix test names
Nov 18, 2025
e30bdad
Merge branch 'main' into ucb_pattern_search
ethche Nov 18, 2025
062be0c
remove comma
Nov 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
- name: Install Helion
run: |
source .venv/bin/activate
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]'
python -c "import helion; print(helion.__name__)"

- name: Install Benchmark Requirements
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ jobs:
run: |
source .venv/bin/activate
uv pip install setuptools ninja
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]'
python -c "import helion; print(helion.__name__)"

- name: Run Tests
Expand Down
30 changes: 19 additions & 11 deletions helion/autotuner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
from __future__ import annotations

from .config_fragment import BooleanFragment as BooleanFragment
from .config_fragment import EnumFragment as EnumFragment
from .config_fragment import IntegerFragment as IntegerFragment
from .config_fragment import ListOf as ListOf
from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
from .config_fragment import (
BooleanFragment as BooleanFragment,
EnumFragment as EnumFragment,
IntegerFragment as IntegerFragment,
ListOf as ListOf,
PowerOfTwoFragment as PowerOfTwoFragment,
)
from .config_spec import ConfigSpec as ConfigSpec
from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid
from .differential_evolution import (
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
)
from .effort_profile import AutotuneEffortProfile as AutotuneEffortProfile
from .effort_profile import DifferentialEvolutionConfig as DifferentialEvolutionConfig
from .effort_profile import PatternSearchConfig as PatternSearchConfig
from .effort_profile import RandomSearchConfig as RandomSearchConfig
from .effort_profile import (
AutotuneEffortProfile as AutotuneEffortProfile,
DifferentialEvolutionConfig as DifferentialEvolutionConfig,
PatternSearchConfig as PatternSearchConfig,
RandomSearchConfig as RandomSearchConfig,
)
from .finite_search import FiniteSearch as FiniteSearch
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
from .local_cache import (
LocalAutotuneCache as LocalAutotuneCache,
StrictLocalAutotuneCache as StrictLocalAutotuneCache,
)
from .pattern_search import PatternSearch as PatternSearch
from .random_search import RandomSearch as RandomSearch
from .ucb_pattern_search import UCBPatternSearch

search_algorithms = {
"DESurrogateHybrid": DESurrogateHybrid,
"UCBPatternSearch": UCBPatternSearch,
"DifferentialEvolutionSearch": DifferentialEvolutionSearch,
"FiniteSearch": FiniteSearch,
"PatternSearch": PatternSearch,
Expand Down
95 changes: 81 additions & 14 deletions helion/autotuner/config_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import dataclasses
import enum
import random
from typing import Iterable
from typing import TypeGuard
from typing import cast
from typing import cast, Iterable, TypeGuard

from ..exc import InvalidConfig

Expand Down Expand Up @@ -51,6 +49,21 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
def is_block_size(self) -> bool:
return False

def is_categorical(self) -> bool:
return True

def encode_dim(self) -> int:
"""
Returns the dimension of the output of encode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what this is? What is it encoding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just the dimension of the encoding

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the name to be more clear

"""
raise NotImplementedError

def encode(self, value: object) -> list[float]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reuse encode_scalar here? If not, I'd like to combine this with encode_scalar since they are solving the same problem.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jansel This is pretty much identical to encode_scalar for integer and poweroftwo. But the previous encode_scalar did not have any functionality for ListOf or PermutationFragments. How should we handle those two?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to rename everything to encode_scalar, but given ListOf and PermutationFragments, I think we should allow encode to output a list of floats

"""
Returns a list of floats that can be used to encode the value of this fragment.
"""
raise NotImplementedError

def get_minimum(self) -> int:
"""
Return the minimum allowed value for this fragment.
Expand Down Expand Up @@ -106,6 +119,17 @@ def pattern_neighbors(self, current: object) -> list[object]:
neighbors.append(swapped)
return neighbors

def encode_dim(self) -> int:
return self.length * self.length

def encode(self, value: object) -> list[float]:
encoded = []
for pos in range(self.length):
val = value[pos]
for v in range(self.length):
encoded.append(1.0 if v == val else 0.0)
return encoded


@dataclasses.dataclass
class BaseIntegerFragment(ConfigSpecFragment):
Expand All @@ -126,6 +150,9 @@ def default(self) -> int:
def clamp(self, val: int) -> int:
return max(min(val, self.high), self.low)

def is_categorical(self) -> bool:
return False

def get_minimum(self) -> int:
return self.low

Expand All @@ -141,13 +168,19 @@ def pattern_neighbors(self, current: object) -> list[object]:
neighbors.append(upper)
return neighbors

def encode_scalar(self, value: object) -> float:
"""Encode integer values directly as floats."""
if not isinstance(value, (int, float)):
raise TypeError(
f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}"
)
return float(value)
def encode_dim(self) -> int:
return 1

def encode(self, value: object) -> list[float]:
"""Encode enum values as their index."""
try:
choice_idx = self.choices.index(value)
except ValueError:
raise ValueError(
f"Invalid enum value {value!r} for EnumFragment. "
f"Valid choices: {self.choices}"
) from None
return [float(choice_idx)]


class PowerOfTwoFragment(BaseIntegerFragment):
Expand Down Expand Up @@ -180,7 +213,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
return self.clamp(ai * 2)
return ai

def encode_scalar(self, value: object) -> float:
def encode_dim(self) -> int:
return 1

def encode(self, value: object) -> list[float]:
"""Encode power-of-2 values using log2 transformation."""
import math

Expand All @@ -192,7 +228,7 @@ def encode_scalar(self, value: object) -> float:
raise ValueError(
f"Expected positive value for PowerOfTwoFragment, got {value}"
)
return math.log2(float(value))
return [math.log2(float(value))]


class IntegerFragment(BaseIntegerFragment):
Expand All @@ -211,6 +247,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
return self.clamp(a + 1)
return a

def encode_dim(self) -> int:
return 1

def encode(self, value: object) -> list[float]:
assert isinstance(value, int)
return [float(value)]


@dataclasses.dataclass
class EnumFragment(ConfigSpecFragment):
Expand All @@ -235,7 +278,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
choices.remove(a)
return random.choice(choices)

def encode_scalar(self, value: object) -> float:
def encode_dim(self) -> int:
return len(self.choices)

def encode(self, value: object) -> list[float]:
"""Encode enum values as their index."""
try:
choice_idx = self.choices.index(value)
Expand All @@ -244,7 +290,7 @@ def encode_scalar(self, value: object) -> float:
f"Invalid enum value {value!r} for EnumFragment. "
f"Valid choices: {self.choices}"
) from None
return float(choice_idx)
return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))]


class BooleanFragment(ConfigSpecFragment):
Expand All @@ -265,6 +311,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool:
return a
return not a

def encode_dim(self) -> int:
return 1

def encode(self, value: object) -> list[float]:
"""Encode enum values as their index."""
assert isinstance(value, bool)
return [1.0] if value else [0.0]


class BlockSizeFragment(PowerOfTwoFragment):
def category(self) -> Category:
Expand Down Expand Up @@ -296,6 +350,9 @@ def random(self) -> list[object]:
"""Return a list of random values."""
return [self.inner.random() for _ in range(self.length)]

def is_categorical(self) -> bool:
return self.inner.is_categorical()

def pattern_neighbors(self, current: object) -> list[object]:
"""Return neighbors by changing one element at a time."""
if not isinstance(current, list) or len(current) != self.length:
Expand All @@ -320,3 +377,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object]
self.inner.differential_mutation(a[i], b[i], c[i])
for i in range(self.length)
]

def encode_dim(self):
return self.length * self.inner.encode_dim()

def encode(self, value: object) -> list[float]:
assert isinstance(value, list)
encoded = []
for v in value:
encoded.extend(self.inner.encode(v))
return encoded
12 changes: 3 additions & 9 deletions helion/autotuner/config_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
import itertools
import operator
import random
from typing import TYPE_CHECKING
from typing import cast
from typing import cast, TYPE_CHECKING

from .._compat import warps_to_threads
from .config_fragment import Category
from .config_fragment import ConfigSpecFragment
from .config_fragment import PowerOfTwoFragment
from .config_fragment import Category, ConfigSpecFragment, PowerOfTwoFragment

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -185,20 +182,17 @@ def differential_mutation(
def encode_config(self, flat_config: FlatConfig) -> list[float]:
"""
Encode a flat configuration into a numerical vector for ML models.

This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need
to represent configurations as continuous vectors for prediction models.

Args:
flat_config: The flat configuration values to encode.

Returns:
A list of floats representing the encoded configuration.
"""
encoded: list[float] = []

for flat_idx, spec in enumerate(self.flat_spec):
value = flat_config[flat_idx]
encoded.append(spec.encode_scalar(value))
encoded.extend(spec.encode(value))

return encoded
Loading
Loading