Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
9c3e56a
Adding UCBPatternSearch autotuner and fragment_encoder
Nov 11, 2025
8035c8e
add tests and import
Nov 11, 2025
e24b529
moved config encoding to config_fragment
Nov 13, 2025
e5be414
Merge remote-tracking branch 'upstream/main' into ucb_pattern_search
Nov 14, 2025
140b3b2
adapt ucb_pattern_search to new encoder
Nov 14, 2025
ca99252
merged new config fragment
Nov 14, 2025
899d30d
imports
Nov 14, 2025
c6fc4cc
Merge branch 'main' into ucb_pattern_search
ethche Nov 14, 2025
13ea3b4
remove encode_scalar
Nov 14, 2025
9cf7dbe
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 14, 2025
3b739f5
fix imports
Nov 15, 2025
62aaf77
early stopping helper for pattern search
Nov 15, 2025
97963f5
fix tests
Nov 16, 2025
a0ef224
fix dim
Nov 16, 2025
018a626
ucb fix lints and better hyperparams
Nov 16, 2025
2a65701
revert linter changes
Nov 16, 2025
79c0fa3
name change
Nov 16, 2025
c2a578f
revert unrelated changes in config_generation
Nov 16, 2025
fc2929d
revert unrelated changes in config_generation
Nov 16, 2025
44c925d
save gp state
Nov 16, 2025
70a46fe
better ucb docstring
Nov 16, 2025
f837953
combined dependencies
Nov 16, 2025
2c2eec9
fix pyproject
Nov 16, 2025
74b0754
reverting unrelated changes to comments
Nov 16, 2025
d9cce1e
no need for encode for integer fragment, inherit from base integer
Nov 16, 2025
3437347
Merge branch 'main' into ucb_pattern_search
ethche Nov 16, 2025
4a791a9
optimize batch UCB function, simplify batch selection
Nov 17, 2025
913b330
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 17, 2025
047509d
batch optimization by default
Nov 17, 2025
eb430ce
LFBO instead of ucb_pattern_search
Nov 18, 2025
33803df
LFBO tests
Nov 18, 2025
b6c24cc
LFBO better docstring
Nov 18, 2025
e3106a8
LFBO remove patience feature
Nov 18, 2025
1c810f0
LFBO imports
Nov 18, 2025
c25362d
Fix comments
Nov 18, 2025
b5a65be
Fix test names
Nov 18, 2025
e30bdad
Merge branch 'main' into ucb_pattern_search
ethche Nov 18, 2025
062be0c
remove comma
Nov 18, 2025
fcb070d
Fix comments
Nov 18, 2025
1572172
Fix comments
Nov 18, 2025
b6b191e
better lfbo hyperparams
Nov 19, 2025
4301669
rename to surrogate
Nov 20, 2025
eef301f
fix linter error for candidates
Nov 20, 2025
5bf793f
Merge branch 'main' into ucb_pattern_search
ethche Nov 20, 2025
de6fd3e
lower case train x train y
Nov 20, 2025
f534766
remove is_categorical
Nov 20, 2025
edf6313
no shape catching for scores
Nov 20, 2025
b3d471b
test remove ignores for linter jobs
Nov 20, 2025
5f600b3
test that dim matches length of encoded value
Nov 20, 2025
deba301
update linter to install surrogate dependencies
Nov 20, 2025
1670416
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 20, 2025
bc0de49
remove another type ignore
Nov 20, 2025
bc8755e
Merge branch 'main' into ucb_pattern_search
ethche Nov 20, 2025
5373a26
add patience and increase init pop
Nov 21, 2025
22ba0ad
set lfbo to be autotuner for benchmark ci
Nov 21, 2025
d87a05a
quantile among finite, fixes shape error
Nov 21, 2025
172ed3d
better generalization for surrogate
Nov 21, 2025
5bc1cd3
Merge branch 'main' into ucb_pattern_search
ethche Nov 21, 2025
9ddf6d8
revert env-vars
Nov 21, 2025
8db5ecb
Merge branch 'ucb_pattern_search' of https://github.com/ethche/helion…
Nov 21, 2025
2b1ff4b
fix docstring
Nov 21, 2025
2112052
debug msg
Nov 21, 2025
a498da5
patience
Nov 21, 2025
7cb0746
restore init
Nov 21, 2025
c813772
patience
Nov 21, 2025
5f66cc2
patience
Nov 21, 2025
7dc6524
remove score logging
Nov 21, 2025
a4fc388
Merge branch 'main' into ucb_pattern_search
ethche Nov 21, 2025
960e526
better docs for _fit_surrogate
ethche Nov 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions helion/autotuner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .fragment_encoder import ConfigEncoder
from .config_fragment import BooleanFragment as BooleanFragment
from .config_fragment import EnumFragment as EnumFragment
from .config_fragment import IntegerFragment as IntegerFragment
Expand All @@ -17,6 +18,7 @@
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
from .pattern_search import PatternSearch as PatternSearch
from .ucb_pattern_search import UCBPatternSearch
from .random_search import RandomSearch as RandomSearch

search_algorithms = {
Expand Down
266 changes: 266 additions & 0 deletions helion/autotuner/fragment_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""Fragment encoding/decoding strategies for machine learning based autotuners.

This module provides a clean abstraction for encoding different fragment types
into numerical tensors and decoding them back. Each fragment type has its own
encoder that knows how to analyze, encode, and decode itself.
"""

from __future__ import annotations

from abc import ABC
from abc import abstractmethod
import math

from .config_fragment import BooleanFragment
from .config_fragment import ConfigSpecFragment
from .config_fragment import EnumFragment
from .config_fragment import IntegerFragment
from .config_fragment import ListOf
from .config_fragment import PermutationFragment
from .config_fragment import PowerOfTwoFragment


class FragmentEncoder(ABC):
"""Base class for encoding/decoding fragment values."""

def __init__(self, fragment: ConfigSpecFragment) -> None:
self.fragment = fragment

@abstractmethod
def n_dims(self) -> int:
"""Return the number of dimensions this fragment uses in encoded space."""

@abstractmethod
def is_categorical(self) -> bool:
"""Return whether this fragment represents categorical data."""

@abstractmethod
def encode(self, value: object) -> list[float]:
"""Encode a value into a list of floats."""

@abstractmethod
def decode(self, encoded: list[float]) -> object:
"""Decode a list of floats back to the original value type."""


class CategoricalEncoder(FragmentEncoder):
"""Encoder for EnumFragment and BooleanFragment using one-hot encoding."""

def __init__(
self, fragment: EnumFragment | BooleanFragment, choices: list[object]
) -> None:
super().__init__(fragment)
self.choices = choices

def n_dims(self) -> int:
return len(self.choices)

def is_categorical(self) -> bool:
return True

def encode(self, value: object) -> list[float]:
idx = self.choices.index(value)
return [1.0 if i == idx else 0.0 for i in range(len(self.choices))]

def decode(self, encoded: list[float]) -> object:
choice_idx = max(range(len(self.choices)), key=lambda i: encoded[i])
return self.choices[choice_idx]


class PowerOfTwoEncoder(FragmentEncoder):
"""Encoder for PowerOfTwoFragment using log2 transformation."""

def __init__(self, fragment: PowerOfTwoFragment) -> None:
super().__init__(fragment)
self.log_min = math.log2(fragment.low)
self.log_max = math.log2(fragment.high)

def n_dims(self) -> int:
return 1

def is_categorical(self) -> bool:
return False

def encode(self, value: int) -> list[float]:
return [math.log2(value)]

def decode(self, encoded: list[float]) -> int:
log_val = encoded[0]
power = int(round(log_val))
power = max(int(self.log_min), min(power, int(self.log_max)))
return 2**power


class IntegerEncoder(FragmentEncoder):
"""Encoder for IntegerFragment using raw values."""

def __init__(self, fragment: IntegerFragment) -> None:
super().__init__(fragment)
self.min_val = fragment.low
self.max_val = fragment.high

def n_dims(self) -> int:
return 1

def is_categorical(self) -> bool:
return False

def encode(self, value: object) -> list[float]:
return [float(value)]

def decode(self, encoded: list[float]) -> int:
value = int(round(encoded[0]))
return max(self.min_val, min(value, self.max_val))


class PermutationEncoder(FragmentEncoder):
"""Encoder for PermutationFragment using one-hot encoding for each position."""

def __init__(self, fragment: PermutationFragment) -> None:
super().__init__(fragment)
self.length = fragment.length

def n_dims(self) -> int:
return self.length * self.length

def is_categorical(self) -> bool:
return True

def encode(self, value: list[int]) -> list[float]:
encoded = []
for pos in range(self.length):
val = value[pos]
for v in range(self.length):
encoded.append(1.0 if v == val else 0.0)
return encoded

def decode(self, encoded: list[float]) -> list[int]:
perm = []
used = set()

for pos in range(self.length):
start_idx = pos * self.length
one_hot = encoded[start_idx : start_idx + self.length]
val = max(range(self.length), key=lambda i: one_hot[i])
perm.append(val)
used.add(val)

# Fix invalid permutation (duplicates/missing values)
if len(used) != self.length:
available = [v for v in range(self.length) if v not in used]
seen = set()
fixed_perm = []
for val in perm:
if val in seen:
fixed_val = available.pop(0)
fixed_perm.append(fixed_val)
else:
fixed_perm.append(val)
seen.add(val)
return fixed_perm

return perm


class ListOfEncoder(FragmentEncoder):
"""Encoder for ListOf fragments, delegates to inner encoder."""

def __init__(self, fragment: ListOf, inner_encoder: FragmentEncoder) -> None:
super().__init__(fragment)
self.length = fragment.length
self.inner_encoder = inner_encoder
self.inner_dims = inner_encoder.n_dims()

def n_dims(self) -> int:
return self.length * self.inner_dims

def is_categorical(self) -> bool:
"""Return True if the inner encoder is categorical."""
return self.inner_encoder.is_categorical()

def encode(self, value: list[object]) -> list[float]:
encoded = []
for v in value:
encoded.extend(self.inner_encoder.encode(v))
return encoded

def decode(self, encoded: list[float]) -> list[object]:
decoded = []
for i in range(self.length):
start_idx = i * self.inner_dims
element_encoded = encoded[start_idx : start_idx + self.inner_dims]
decoded.append(self.inner_encoder.decode(element_encoded))
return decoded


def create_encoder(fragment: ConfigSpecFragment) -> FragmentEncoder:
"""Factory function to create the appropriate encoder for a fragment."""
if isinstance(fragment, BooleanFragment):
return CategoricalEncoder(fragment, [False, True])
if isinstance(fragment, EnumFragment):
return CategoricalEncoder(fragment, list(fragment.choices))
if isinstance(fragment, PowerOfTwoFragment):
return PowerOfTwoEncoder(fragment)
if isinstance(fragment, IntegerFragment):
return IntegerEncoder(fragment)
if isinstance(fragment, PermutationFragment):
return PermutationEncoder(fragment)
if isinstance(fragment, ListOf):
inner_encoder = create_encoder(fragment.inner)
return ListOfEncoder(fragment, inner_encoder)
raise ValueError(f"Unsupported fragment type: {type(fragment).__name__}")


class ConfigEncoder:
"""Encodes and decodes entire configurations using fragment encoders."""

def __init__(self, flat_spec: list[ConfigSpecFragment]) -> None:
"""Initialize encoders for all fragments in the spec.

Args:
flat_spec: List of fragment specifications
"""
self.encoders = [create_encoder(fragment) for fragment in flat_spec]
self.total_dims = sum(encoder.n_dims() for encoder in self.encoders)

# Build categorical dimension indices (absolute positions)
self.cat_dims = []
offset = 0
for encoder in self.encoders:
n_dims = encoder.n_dims()
if encoder.is_categorical():
# All dimensions of this encoder are categorical
self.cat_dims.extend(range(offset, offset + n_dims))
offset += n_dims

def encode(self, flat_config: list[object]) -> list[float]:
"""Encode a flat configuration into a list of floats.

Args:
flat_config: List of configuration values

Returns:
List of encoded float values
"""
encoded = []
for value, encoder in zip(flat_config, self.encoders, strict=False):
encoded.extend(encoder.encode(value))
return encoded

def decode(self, encoded: list[float]) -> list[object]:
"""Decode a list of floats back into a flat configuration.

Args:
encoded: List of encoded float values

Returns:
List of decoded configuration values
"""
decoded = []
idx = 0
for encoder in self.encoders:
n_dims = encoder.n_dims()
fragment_encoded = encoded[idx : idx + n_dims]
decoded.append(encoder.decode(fragment_encoded))
idx += n_dims
return decoded
Loading
Loading