diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 47ac23bdd..4d57f8db6 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -100,7 +100,7 @@ jobs:
       - name: Install Helion
         run: |
           source .venv/bin/activate
-          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
+          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]'
           python -c "import helion; print(helion.__name__)"

       - name: Install Benchmark Requirements
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f4be7527d..2e5abaabb 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -146,7 +146,7 @@ jobs:
         run: |
           source .venv/bin/activate
           uv pip install setuptools ninja
-          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
+          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]'
           python -c "import helion; print(helion.__name__)"

       - name: Run Tests
diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py
index 52a0c672e..7d08e9603 100644
--- a/helion/autotuner/__init__.py
+++ b/helion/autotuner/__init__.py
@@ -19,9 +19,11 @@
 from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
 from .pattern_search import PatternSearch as PatternSearch
 from .random_search import RandomSearch as RandomSearch
+from .surrogate_pattern_search import LFBOPatternSearch as LFBOPatternSearch

 search_algorithms = {
     "DESurrogateHybrid": DESurrogateHybrid,
+    "LFBOPatternSearch": LFBOPatternSearch,
     "DifferentialEvolutionSearch": DifferentialEvolutionSearch,
     "FiniteSearch": FiniteSearch,
     "PatternSearch": PatternSearch,
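The registry entry above is what makes the new search reachable by name. A minimal usage sketch of resolving it from the registry; `my_kernel` and `args` are placeholders rather than names from this PR, and the `bind`/`autotune` calls mirror the test added at the bottom of this patch:

```python
from helion.autotuner import search_algorithms

# Look up the newly registered algorithm by name.
search_cls = search_algorithms["LFBOPatternSearch"]

# Hypothetical kernel and inputs; in the tests this is basic_kernels.add.
bound_kernel = my_kernel.bind(args)
best_config = search_cls(bound_kernel, args).autotune()
```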
+ """ + raise NotImplementedError @dataclasses.dataclass @@ -106,6 +110,17 @@ def pattern_neighbors(self, current: object) -> list[object]: neighbors.append(swapped) return neighbors + def dim(self) -> int: + return self.length + + def encode(self, value: object) -> list[float]: + assert isinstance(value, list) + encoded = [] + for val in value: + assert isinstance(val, int) + encoded.append(float(val)) + return value + @dataclasses.dataclass class BaseIntegerFragment(ConfigSpecFragment): @@ -126,9 +141,15 @@ def default(self) -> int: def clamp(self, val: int) -> int: return max(min(val, self.high), self.low) + def is_categorical(self) -> bool: + return False + def get_minimum(self) -> int: return self.low + def dim(self) -> int: + return 1 + def pattern_neighbors(self, current: object) -> list[object]: if type(current) is not int: # bool is not allowed raise TypeError(f"Expected int, got {type(current).__name__}") @@ -141,13 +162,9 @@ def pattern_neighbors(self, current: object) -> list[object]: neighbors.append(upper) return neighbors - def encode_scalar(self, value: object) -> float: - """Encode integer values directly as floats.""" - if not isinstance(value, (int, float)): - raise TypeError( - f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}" - ) - return float(value) + def encode(self, value: object) -> list[float]: + assert isinstance(value, int) + return [float(value)] class PowerOfTwoFragment(BaseIntegerFragment): @@ -180,7 +197,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: return self.clamp(ai * 2) return ai - def encode_scalar(self, value: object) -> float: + def encode(self, value: object) -> list[float]: """Encode power-of-2 values using log2 transformation.""" import math @@ -192,7 +209,7 @@ def encode_scalar(self, value: object) -> float: raise ValueError( f"Expected positive value for PowerOfTwoFragment, got {value}" ) - return math.log2(float(value)) + return [math.log2(float(value))] class IntegerFragment(BaseIntegerFragment): @@ -235,7 +252,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> object: choices.remove(a) return random.choice(choices) - def encode_scalar(self, value: object) -> float: + def dim(self) -> int: + return len(self.choices) + + def encode(self, value: object) -> list[float]: """Encode enum values as their index.""" try: choice_idx = self.choices.index(value) @@ -244,7 +264,7 @@ def encode_scalar(self, value: object) -> float: f"Invalid enum value {value!r} for EnumFragment. 
" f"Valid choices: {self.choices}" ) from None - return float(choice_idx) + return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))] class BooleanFragment(ConfigSpecFragment): @@ -265,6 +285,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool: return a return not a + def dim(self) -> int: + return 1 + + def encode(self, value: object) -> list[float]: + """Encode enum values as their index.""" + assert isinstance(value, bool) + return [1.0] if value else [0.0] + class BlockSizeFragment(PowerOfTwoFragment): def category(self) -> Category: @@ -296,6 +324,9 @@ def random(self) -> list[object]: """Return a list of random values.""" return [self.inner.random() for _ in range(self.length)] + def is_categorical(self) -> bool: + return self.inner.is_categorical() + def pattern_neighbors(self, current: object) -> list[object]: """Return neighbors by changing one element at a time.""" if not isinstance(current, list) or len(current) != self.length: @@ -320,3 +351,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object] self.inner.differential_mutation(a[i], b[i], c[i]) for i in range(self.length) ] + + def dim(self) -> int: + return self.length * self.inner.dim() + + def encode(self, value: object) -> list[float]: + assert isinstance(value, list) + encoded = [] + for v in value: + encoded.extend(self.inner.encode(v)) + return encoded diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index 0747f41a9..9e1533be4 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -199,6 +199,6 @@ def encode_config(self, flat_config: FlatConfig) -> list[float]: for flat_idx, spec in enumerate(self.flat_spec): value = flat_config[flat_idx] - encoded.append(spec.encode_scalar(value)) + encoded.extend(spec.encode(value)) return encoded diff --git a/helion/autotuner/de_surrogate_hybrid.py b/helion/autotuner/de_surrogate_hybrid.py index b04cbfb9b..e6a36df5e 100644 --- a/helion/autotuner/de_surrogate_hybrid.py +++ b/helion/autotuner/de_surrogate_hybrid.py @@ -91,7 +91,7 @@ def __init__( if not HAS_ML_DEPS: raise ImportError( "DESurrogateHybrid requires numpy and scikit-learn. 
" - "Install them with: pip install helion[de-surrogate]" + "Install them with: pip install helion[surrogate]" ) # Initialize parent with early stopping parameters diff --git a/helion/autotuner/pattern_search.py b/helion/autotuner/pattern_search.py index d8d759b31..6e36ea36d 100644 --- a/helion/autotuner/pattern_search.py +++ b/helion/autotuner/pattern_search.py @@ -134,20 +134,35 @@ def _pattern_search_from( if len(candidates) <= 1: return # no new candidates, stop searching yield candidates # yield new population to benchmark in parallel + # update search copy and check early stopping criteria best = min(candidates, key=performance) - if best is current: - return # no improvement, stop searching - # Stop if the relative improvement is smaller than a user-specified delta - if ( - self.min_improvement_delta > 0.0 - and math.isfinite(best.perf) - and math.isfinite(current.perf) - and current.perf != 0.0 - and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta - ): + if self._check_early_stopping(best, current): return current = best + def _check_early_stopping( + self, best: PopulationMember, current: PopulationMember + ) -> bool: + """ + Check if early stopping criteria are met for the search copy + + Early stops if either the best config has not changed or if + the relative improvement is smaller than a user-specified delta + + Returns: + True the search copy is terminated, False otherwise. + """ + if best is current: + return True # no improvement, stop searching + # Stop if the relative improvement is smaller than a user-specified delta + return bool( + self.min_improvement_delta > 0.0 + and math.isfinite(best.perf) + and math.isfinite(current.perf) + and current.perf != 0.0 + and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta + ) + def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: """ Generate neighboring configurations by changing one or two parameters at a time. diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py new file mode 100644 index 000000000..afc9b6865 --- /dev/null +++ b/helion/autotuner/surrogate_pattern_search.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import math +import operator +import random +from typing import TYPE_CHECKING + +from .. import exc +from .base_search import FlatConfig +from .base_search import PopulationMember +from .base_search import performance +from .config_fragment import PowerOfTwoFragment +from .effort_profile import PATTERN_SEARCH_DEFAULTS +from .pattern_search import PatternSearch + +if TYPE_CHECKING: + from collections.abc import Iterator + from collections.abc import Sequence + + from ..runtime.config import Config + from ..runtime.kernel import BoundKernel + +try: + import numpy as np # type: ignore[import-not-found] + from sklearn.ensemble import ( # type: ignore[import-not-found] + RandomForestClassifier, + ) + + HAS_ML_DEPS = True +except ImportError as e: + HAS_ML_DEPS = False + _IMPORT_ERROR = e + + +class LFBOPatternSearch(PatternSearch): + """ + Likelihood-Free Bayesian Optimization (LFBO) Pattern Search. + + This algorithm enhances PatternSearch by using a Random Forest classifier as a surrogate + model to select which configurations to benchmark, reducing the number of + kernel compilations and runs needed to find optimal configurations. + + Algorithm Overview: + 1. Generate an initial random population and benchmark all configurations + 2. 
+    2. Fit a Random Forest classifier to predict "good" vs "bad" configurations:
+       - Configs with performance < quantile threshold are labeled as "good" (class 1)
+       - Configs with performance >= quantile threshold are labeled as "bad" (class 0)
+       - Weighted classification emphasizes configs that are much better than the threshold
+    3. For each generation:
+       - Generate random neighbors around the current best configurations
+       - Score all neighbors using the classifier's predicted probability of being "good"
+       - Benchmark only the top frac_selected fraction of neighbors
+       - Retrain the classifier on all observed data (not incremental)
+       - Update search trajectories based on new results
+
+    The weighted classification model learns to identify which configs maximize
+    expected improvement over the current best config. Unlike fitting a regression
+    surrogate to the config performances themselves, this classification-based
+    approach can also learn from configs that time out or have unacceptable accuracy.
+
+    References:
+        - Song, J., et al. (2022). "A General Recipe for Likelihood-free Bayesian Optimization."
+
+    Args:
+        kernel: The kernel to be autotuned.
+        args: The arguments to be passed to the kernel during benchmarking.
+        initial_population: Number of random configurations in initial population.
+            Default from PATTERN_SEARCH_DEFAULTS.
+        copies: Number of top configurations to run pattern search from.
+            Default from PATTERN_SEARCH_DEFAULTS.
+        max_generations: Maximum number of search iterations per copy.
+            Default from PATTERN_SEARCH_DEFAULTS.
+        min_improvement_delta: Early stopping threshold. Search stops if the relative
+            improvement abs(best/current - 1) < min_improvement_delta.
+            Default: 0.001 (0.1% improvement threshold).
+        frac_selected: Fraction of generated neighbors to actually benchmark, after
+            filtering by classifier score. Range: (0, 1]. Lower values reduce benchmarking
+            cost but may miss good configurations. Default: 0.4.
+        num_neighbors: Number of random neighbor configurations to generate around
+            each search point per generation. Default: 100.
+        radius: Maximum perturbation distance in configuration space. For power-of-two
+            parameters, this is the max change in log2 space. For other parameters,
+            this limits how many parameters can be changed. Default: 2.
+        quantile: Threshold for labeling configs as "good" (class 1) vs "bad" (class 0).
+            Configs with performance below this quantile are labeled as good.
+            Range: (0, 1). Lower values create a more selective definition of "good".
+            Default: 0.3 (top 30% are considered good).
+    """
+
+    def __init__(
+        self,
+        kernel: BoundKernel,
+        args: Sequence[object],
+        *,
+        initial_population: int = PATTERN_SEARCH_DEFAULTS.initial_population,
+        copies: int = PATTERN_SEARCH_DEFAULTS.copies,
+        max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations,
+        min_improvement_delta: float = 0.001,
+        frac_selected: float = 0.4,
+        num_neighbors: int = 100,
+        radius: int = 2,
+        quantile: float = 0.3,
+    ) -> None:
+        if not HAS_ML_DEPS:
+            raise exc.AutotuneError(
+                "LFBOPatternSearch requires numpy and scikit-learn. "
+                "Install them with: pip install helion[surrogate]"
+            ) from _IMPORT_ERROR
+ "Install them with: pip install helion[surrogate]" + ) from _IMPORT_ERROR + + super().__init__( + kernel=kernel, + args=args, + initial_population=initial_population, + copies=copies, + max_generations=max_generations, + min_improvement_delta=min_improvement_delta, + ) + + # Number of neighbors and how many to evalaute + self.num_neighbors = num_neighbors + self.radius = radius + self.frac_selected = frac_selected + + # Save training data + self.train_X = [] + self.train_Y = [] + self.quantile = quantile + self.surrogate = None + + def _fit_surrogate(self) -> None: + train_X = np.array(self.train_X) # type: ignore[union-attr] + train_Y = np.array(self.train_Y) # type: ignore[union-attr] + self.log.debug( + f"Fitting surrogate: {len(train_X)} points, {len(train_Y)} targets" + ) + train_Y_quantile = np.quantile(train_Y, self.quantile) # type: ignore[union-attr] + + # Labels are generated by which are configs better than the quantile + train_labels = 1.0 * (train_Y < train_Y_quantile) + pos_weights = np.maximum(0, train_Y_quantile - train_Y) # type: ignore[union-attr] + normalizing_factor = np.mean( # type: ignore[union-attr] + np.array([weight for weight in pos_weights if weight > 0.0]) # type: ignore[union-attr] + ) + pos_weights = pos_weights / normalizing_factor + sample_weight = np.where(train_Y < train_Y_quantile, pos_weights, 1.0) # type: ignore[union-attr] + + self.surrogate = RandomForestClassifier( # type: ignore[misc] + criterion="log_loss", + random_state=42, + n_estimators=100, + min_samples_split=2, + min_samples_leaf=1, + n_jobs=-1, + ) + self.surrogate.fit(train_X, train_labels, sample_weight=sample_weight) + + def _surrogate_select( + self, candidates: list[PopulationMember], n_sorted: int + ) -> list[PopulationMember]: + # Score candidates + candidate_X = np.array( # type: ignore[union-attr] + [self.config_gen.encode_config(member.flat_values) for member in candidates] + ) + scores = self.surrogate.predict_proba(candidate_X) # type: ignore[assignment] + + if scores.shape[1] == 2: # type: ignore[union-attr] + scores = scores[:, 1] # type: ignore[index] + elif scores.shape[1] == 1: # type: ignore[union-attr] + # If probabilities are all 1, then the model outputs a 1D vector. 
+
+    def _surrogate_select(
+        self, candidates: list[PopulationMember], n_sorted: int
+    ) -> list[PopulationMember]:
+        # Score candidates
+        candidate_X = np.array(  # type: ignore[union-attr]
+            [self.config_gen.encode_config(member.flat_values) for member in candidates]
+        )
+        scores = self.surrogate.predict_proba(candidate_X)  # type: ignore[assignment]
+
+        if scores.shape[1] == 2:  # type: ignore[union-attr]
+            scores = scores[:, 1]  # type: ignore[index]
+        elif scores.shape[1] == 1:  # type: ignore[union-attr]
+            # If training saw only one class, predict_proba returns a single column.
+            scores = scores[:, 0]  # type: ignore[index]
+        else:
+            raise ValueError("Unexpected shape for scores")
+
+        # sort candidates by score
+        candidates_sorted = sorted(
+            zip(candidates, scores, strict=True),
+            key=operator.itemgetter(1),
+            reverse=True,  # higher scores are better
+        )[:n_sorted]
+
+        self.log.debug(
+            f"Scored {len(candidate_X)} neighbors, selected top {(n_sorted / len(candidate_X)) * 100:.0f}%: {len(candidates_sorted)} configs"
+        )
+
+        return [member for member, score in candidates_sorted]
+
+    def _autotune(self) -> Config:
+        self.log(
+            f"Starting LFBOPatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}"
+        )
+        visited = set()
+        self.population = []
+        for flat_config in self.config_gen.random_population_flat(
+            self.initial_population
+        ):
+            member = self.make_unbenchmarked(flat_config)
+            if member.config not in visited:
+                visited.add(member.config)
+                self.population.append(member)
+        self.parallel_benchmark_population(self.population, desc="Initial population")
+        # again with higher accuracy
+        self.rebenchmark_population(self.population, desc="Verifying initial results")
+        self.population.sort(key=performance)
+        starting_points = []
+        for member in self.population[: self.copies]:
+            if math.isfinite(member.perf):  # filter failed compiles
+                starting_points.append(member)
+        self.log(
+            f"Initial random population of {len(self.population)}, {len(starting_points)} starting points:",
+            self.statistics,
+        )
+        if not starting_points:
+            raise exc.NoConfigFound
+
+        # Save to training data
+        for member in self.population:
+            self.train_X.append(self.config_gen.encode_config(member.flat_values))
+            self.train_Y.append(member.perf)
+
+        # Fit model
+        self._fit_surrogate()
+
+        search_copies = [
+            self._pruned_pattern_search_from(m, visited) for m in starting_points
+        ]
+        for generation in range(1, self.max_generations + 1):
+            prior_best = self.best
+            new_population = {id(prior_best): prior_best}
+            num_neighbors = 0
+            num_active = 0
+            for search_copy in search_copies:
+                added = next(search_copy, ())
+                if added:
+                    assert len(added) > 1
+                    num_active += 1
+                    num_neighbors += len(added) - 1
+                    for member in added:
+                        new_population[id(member)] = member
+            if num_active == 0:
+                break
+
+            # Log generation header before compiling/benchmarking
+            self.log(
+                f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)"
+            )
+
+            self.population = [*new_population.values()]
+            # compile any unbenchmarked members in parallel
+            unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
+            if unbenchmarked:
+                self.parallel_benchmark_population(
+                    unbenchmarked, desc=f"Generation {generation}:"
+                )
+            # higher-accuracy rebenchmark
+            self.rebenchmark_population(
+                self.population, desc=f"Generation {generation}: verifying top configs"
+            )
+            # Log final statistics for this generation
+            self.log(f"Generation {generation} complete:", self.statistics)
+
+            # Update training data
+            for member in self.population:
+                self.train_X.append(self.config_gen.encode_config(member.flat_values))
+                self.train_Y.append(member.perf)
+
+            # Fit model
+            self._fit_surrogate()
+
+        return self.best.config
+
+    def _random_log2_neighbor(
+        self, current_val: int, radius: int, low: int, high: int
+    ) -> int:
+        # Log the current value
+        current_log = int(math.log2(current_val))
+        # Random log perturbation
+        delta = random.randint(-radius, radius)
+        new_log = current_log + delta
+        # Clamp to valid range
+        min_log = int(math.log2(low))
+        max_log = int(math.log2(high))
+        new_log = max(min_log, min(new_log, max_log))
+        return int(2**new_log)
+
+    def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]:
+        """
+        Generate neighboring configurations randomly within a specified radius.
+
+        Strategy:
+        1. Sample one block size index and change it by at most radius (in log2 space)
+        2. Sample the num_warps index and change it by at most radius (in log2 space)
+        3. For at most radius remaining indices, randomly select pattern neighbors
+
+        Args:
+            base: The base configuration to generate neighbors from
+
+        Returns:
+            A list of neighboring configurations
+        """
+        neighbors: list[FlatConfig] = []
+
+        # Generate num_neighbors random neighbors
+        for _ in range(self.num_neighbors):
+            new_flat = [*base]  # Copy the base configuration
+            modified_indices = set()
+
+            # 1. Sample a block size index and change it by at most radius
+            if self.config_gen.block_size_indices:
+                block_idx = random.choice(self.config_gen.block_size_indices)
+                modified_indices.add(block_idx)
+
+                block_spec = self.config_gen.flat_spec[block_idx]
+                current_val = base[block_idx]
+                assert isinstance(current_val, int)
+
+                if isinstance(block_spec, PowerOfTwoFragment):
+                    # Change by at most radius in log2 space
+                    new_flat[block_idx] = self._random_log2_neighbor(
+                        current_val,
+                        radius=self.radius,
+                        low=block_spec.low,
+                        high=block_spec.high,
+                    )
+                else:
+                    raise ValueError("BlockSize should be PowerOfTwoFragment")
+
+            # 2. Sample the num_warps index and change it by at most radius
+            if self.config_gen.num_warps_index:
+                warp_idx = self.config_gen.num_warps_index
+                modified_indices.add(warp_idx)
+
+                warp_spec = self.config_gen.flat_spec[warp_idx]
+                current_val = base[warp_idx]
+                assert isinstance(current_val, int)
+
+                if isinstance(warp_spec, PowerOfTwoFragment):
+                    # Change by at most self.radius in log2 space
+                    new_flat[warp_idx] = self._random_log2_neighbor(
+                        current_val,
+                        radius=self.radius,
+                        low=warp_spec.low,
+                        high=warp_spec.high,
+                    )
+                else:
+                    raise ValueError("NumWarps should be PowerOfTwoFragment")
+
+            # 3. For at most radius remaining indices, use pattern neighbors
+            # Exclude the already-modified block size and warp indices
+
+            # Collect available pattern neighbors for remaining indices
+            remaining_pattern_neighbors = []
+            for index, spec in enumerate(self.config_gen.flat_spec):
+                if index not in modified_indices:
+                    pattern_neighbors = spec.pattern_neighbors(base[index])
+                    if pattern_neighbors:
+                        remaining_pattern_neighbors.append((index, pattern_neighbors))
+
+            # Randomly select at most radius indices to change
+            if remaining_pattern_neighbors:
+                num_to_change = random.randint(
+                    0, min(self.radius, len(remaining_pattern_neighbors))
+                )
+                if num_to_change > 0:
+                    indices_to_change = random.sample(
+                        remaining_pattern_neighbors, num_to_change
+                    )
+                    for idx, pattern_neighbors in indices_to_change:
+                        new_flat[idx] = random.choice(pattern_neighbors)
+
+            # Only add if it's different from the base
+            if new_flat != base:
+                neighbors.append(new_flat)
+
+        return neighbors
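As a sanity check on the radius semantics: a power-of-two parameter moves at most `radius` steps in log2 space and is then clamped to its bounds. A standalone sketch using the same arithmetic as `_random_log2_neighbor`; the bounds here are hypothetical:

```python
import math

current_val, radius, low, high = 32, 2, 16, 128

current_log = int(math.log2(current_val))  # 5
reachable = set()
for delta in range(-radius, radius + 1):
    new_log = current_log + delta  # 3..7 before clamping
    new_log = max(int(math.log2(low)), min(new_log, int(math.log2(high))))
    reachable.add(2**new_log)
print(sorted(reachable))  # [16, 32, 64, 128] -- 8 is clamped up to low=16
```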
+
+    def _pruned_pattern_search_from(
+        self,
+        current: PopulationMember,
+        visited: set[Config],
+    ) -> Iterator[list[PopulationMember]]:
+        """
+        Run a single copy of pattern search from the given starting point.
+
+        We use a generator and yield the new population at each generation so that we can
+        run multiple copies of pattern search in parallel.
+
+        Only keep self.frac_selected of the neighbors generated from the current
+        search copy, using _surrogate_select.
+
+        Args:
+            current: The current best configuration.
+            visited: A set of visited configurations.
+
+        Returns:
+            A generator that yields the new population at each generation.
+        """
+        for _ in range(self.max_generations):
+            candidates = [current]
+            all_neighbors = self._generate_neighbors(current.flat_values)
+            for flat_config in all_neighbors:
+                new_member = self.make_unbenchmarked(flat_config)
+                if new_member.config not in visited:
+                    candidates.append(new_member)
+                    visited.add(new_member.config)
+
+            # score candidates
+            n_sorted = int(len(candidates) * self.frac_selected)
+            candidates = self._surrogate_select(candidates, n_sorted)
+
+            if len(candidates) <= 1:
+                return  # no new candidates, stop searching
+            yield candidates  # yield new population to benchmark in parallel
+            best = min(candidates, key=performance)
+            if self._check_early_stopping(best, current):
+                return
+            current = best
diff --git a/pyproject.toml b/pyproject.toml
index bc76c5127..e02adb30c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-de-surrogate = [
+surrogate = [
     "numpy",
     "scikit-learn>=1.3.0"
 ]
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index 7c6fcd4d6..d7f1258df 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -35,6 +35,7 @@
 from helion._testing import skipIfRocm
 from helion.autotuner import DESurrogateHybrid
 from helion.autotuner import DifferentialEvolutionSearch
+from helion.autotuner import LFBOPatternSearch
 from helion.autotuner import PatternSearch
 from helion.autotuner.base_search import BaseSearch
 from helion.autotuner.base_search import PopulationMember
@@ -42,6 +43,7 @@
 from helion.autotuner.config_fragment import EnumFragment
 from helion.autotuner.config_fragment import IntegerFragment
 from helion.autotuner.config_fragment import ListOf
+from helion.autotuner.config_fragment import PermutationFragment
 from helion.autotuner.config_fragment import PowerOfTwoFragment
 from helion.autotuner.config_generation import ConfigGeneration
 from helion.autotuner.effort_profile import get_effort_profile
@@ -622,6 +624,67 @@ def diff_count(flat):
         ]
         self.assertEqual(sorted(pair_neighbors), sorted(expected))

+    def test_lfbo_pattern_search_generate_neighbors(self):
+        """Test LFBOPatternSearch._generate_neighbors method."""
+        random.seed(123)
+        search = LFBOPatternSearch.__new__(LFBOPatternSearch)
+        search.num_neighbors = 50
+        search.radius = 2
+        search.config_gen = SimpleNamespace(
+            flat_spec=[
+                PowerOfTwoFragment(16, 128, 32),  # block_size[0]
+                PowerOfTwoFragment(16, 128, 64),  # block_size[1]
+                PowerOfTwoFragment(2, 16, 4),  # num_warps
+                EnumFragment(("a", "b", "c")),  # some enum
+                BooleanFragment(),  # some boolean
+            ],
+            block_size_indices=[0, 1],
+            num_warps_index=2,
+        )
+
+        base = [32, 64, 4, "b", True]
+        neighbors = search._generate_neighbors(base)
+
+        # Check we generate the correct number of neighbors
+        self.assertEqual(len(neighbors), search.num_neighbors)
+
+        # Check all neighbors are different from base
+        for neighbor in neighbors:
+            self.assertNotEqual(neighbor, base)
+
+        # Verify all block sizes are valid powers of two in range
+        for neighbor in neighbors:
+            # Check block_size[0]
+            self.assertIn(neighbor[0], [16, 32, 64, 128])
+            # Check block_size[1]
+            self.assertIn(neighbor[1], [16, 32, 64, 128])
+            # Check num_warps
+            self.assertIn(neighbor[2], [2, 4, 8, 16])
+            # Check enum
+            self.assertIn(neighbor[3], ["a", "b", "c"])
+            # Check boolean
+            self.assertIn(neighbor[4], [True, False])
+
on rocm") + @skip("too slow") + def test_lfbo_pattern_search(self): + args = ( + torch.randn([64, 64], device=DEVICE), + torch.randn([64, 64], device=DEVICE), + ) + bound_kernel = basic_kernels.add.bind(args) + random.seed(123) + best = LFBOPatternSearch( + bound_kernel, + args, + initial_population=10, + max_generations=2, + copies=1, + num_neighbors=10, + ).autotune() + fn = bound_kernel.compile_config(best) + torch.testing.assert_close(fn(*args), sum(args), rtol=1e-2, atol=1e-1) + @skipIfCpu("fails on Triton CPU backend") def test_accuracy_check_filters_bad_config_wrong_output(self) -> None: bad_config = helion.Config(block_sizes=[1], num_warps=8) @@ -1062,6 +1125,53 @@ def add(a, b): ): add(*args) + def test_fragment_encoding(self): + """Test encoding functionality for all ConfigSpecFragment types.""" + # Test BooleanFragment + bool_frag = BooleanFragment() + self.assertEqual(bool_frag.dim(), 1) + self.assertEqual(bool_frag.encode(True), [1.0]) + self.assertEqual(bool_frag.encode(False), [0.0]) + + # Test IntegerFragment + int_frag = IntegerFragment(low=1, high=10, default_val=5) + self.assertEqual(int_frag.dim(), 1) + self.assertEqual(int_frag.encode(5), [5.0]) + + # Test PowerOfTwoFragment (log2 transformation) + pow2_frag = PowerOfTwoFragment(low=2, high=128, default_val=8) + self.assertEqual(pow2_frag.dim(), 1) + self.assertEqual(pow2_frag.encode(8), [3.0]) # log2(8) = 3 + self.assertEqual(pow2_frag.encode(16), [4.0]) # log2(16) = 4 + + # Test EnumFragment (one-hot encoding) + enum_frag = EnumFragment(choices=("a", "b", "c")) + self.assertEqual(enum_frag.dim(), 3) + self.assertEqual(enum_frag.encode("a"), [1.0, 0.0, 0.0]) + self.assertEqual(enum_frag.encode("b"), [0.0, 1.0, 0.0]) + + # Test PermutationFragment + perm_frag = PermutationFragment(length=3) + self.assertEqual(perm_frag.dim(), 3) + encoded = perm_frag.encode([0, 1, 2]) + self.assertEqual(encoded, [0, 1, 2]) + + # Test ListOf with BooleanFragment + list_frag = ListOf(inner=BooleanFragment(), length=3) + self.assertEqual(list_frag.dim(), 3) + self.assertEqual(list_frag.encode([True, False, True]), [1.0, 0.0, 1.0]) + + # Test encode_dim consistency + for fragment, value in [ + (BooleanFragment(), True), + (IntegerFragment(1, 10, 5), 5), + (PowerOfTwoFragment(2, 128, 8), 16), + (EnumFragment(choices=("a", "b")), "b"), + ]: + dim = fragment.dim() + encoded = fragment.encode(value) + self.assertEqual(len(encoded), dim) + class TestAutotuneRandomSeed(RefEagerTestDisabled, TestCase): def _autotune_and_record(self, **settings: object) -> float: