-
Notifications
You must be signed in to change notification settings - Fork 68
Add LFBO Pattern Search #1115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add LFBO Pattern Search #1115
Changes from 8 commits
9c3e56a
8035c8e
e24b529
e5be414
140b3b2
ca99252
899d30d
c6fc4cc
13ea3b4
9cf7dbe
3b739f5
62aaf77
97963f5
a0ef224
018a626
2a65701
79c0fa3
c2a578f
fc2929d
44c925d
70a46fe
f837953
2c2eec9
74b0754
d9cce1e
3437347
4a791a9
913b330
047509d
eb430ce
33803df
b6c24cc
e3106a8
1c810f0
c25362d
b5a65be
e30bdad
062be0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,9 +3,7 @@ | |
| import dataclasses | ||
| import enum | ||
| import random | ||
| from typing import Iterable | ||
| from typing import TypeGuard | ||
| from typing import cast | ||
| from typing import cast, Iterable, TypeGuard | ||
|
|
||
| from ..exc import InvalidConfig | ||
|
|
||
|
|
@@ -51,6 +49,21 @@ def differential_mutation(self, a: object, b: object, c: object) -> object: | |
| def is_block_size(self) -> bool: | ||
| return False | ||
|
|
||
| def is_categorical(self) -> bool: | ||
| return True | ||
|
|
||
| def encode_dim(self) -> int: | ||
| """ | ||
| Returns the dimension of the output of encode | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand what this is? What is it encoding?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just the dimension of the encoding
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed the name to be more clear |
||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we reuse
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jansel This is pretty much identical to encode_scalar for integer and poweroftwo. But the previous encode_scalar did not have any functionality for ListOf or PermutationFragments. How should we handle those two?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm happy to rename everything to encode_scalar, but given ListOf and PermutationFragments, I think we should allow encode to output a list of floats |
||
| """ | ||
| Returns a list of floats that can be used to encode the value of this fragment. | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| def get_minimum(self) -> int: | ||
| """ | ||
| Return the minimum allowed value for this fragment. | ||
|
|
@@ -106,6 +119,17 @@ def pattern_neighbors(self, current: object) -> list[object]: | |
| neighbors.append(swapped) | ||
| return neighbors | ||
|
|
||
| def encode_dim(self) -> int: | ||
| return self.length * self.length | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| encoded = [] | ||
| for pos in range(self.length): | ||
| val = value[pos] | ||
| for v in range(self.length): | ||
| encoded.append(1.0 if v == val else 0.0) | ||
| return encoded | ||
ethche marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class BaseIntegerFragment(ConfigSpecFragment): | ||
|
|
@@ -126,6 +150,9 @@ def default(self) -> int: | |
| def clamp(self, val: int) -> int: | ||
| return max(min(val, self.high), self.low) | ||
|
|
||
| def is_categorical(self) -> bool: | ||
| return False | ||
|
|
||
| def get_minimum(self) -> int: | ||
| return self.low | ||
|
|
||
|
|
@@ -141,13 +168,19 @@ def pattern_neighbors(self, current: object) -> list[object]: | |
| neighbors.append(upper) | ||
| return neighbors | ||
|
|
||
| def encode_scalar(self, value: object) -> float: | ||
| """Encode integer values directly as floats.""" | ||
| if not isinstance(value, (int, float)): | ||
| raise TypeError( | ||
| f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}" | ||
| ) | ||
| return float(value) | ||
| def encode_dim(self) -> int: | ||
| return 1 | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| """Encode enum values as their index.""" | ||
| try: | ||
| choice_idx = self.choices.index(value) | ||
| except ValueError: | ||
| raise ValueError( | ||
| f"Invalid enum value {value!r} for EnumFragment. " | ||
| f"Valid choices: {self.choices}" | ||
| ) from None | ||
| return [float(choice_idx)] | ||
|
|
||
|
|
||
| class PowerOfTwoFragment(BaseIntegerFragment): | ||
|
|
@@ -180,7 +213,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: | |
| return self.clamp(ai * 2) | ||
| return ai | ||
|
|
||
| def encode_scalar(self, value: object) -> float: | ||
| def encode_dim(self) -> int: | ||
| return 1 | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| """Encode power-of-2 values using log2 transformation.""" | ||
| import math | ||
|
|
||
|
|
@@ -192,7 +228,7 @@ def encode_scalar(self, value: object) -> float: | |
| raise ValueError( | ||
| f"Expected positive value for PowerOfTwoFragment, got {value}" | ||
| ) | ||
| return math.log2(float(value)) | ||
| return [math.log2(float(value))] | ||
|
|
||
|
|
||
| class IntegerFragment(BaseIntegerFragment): | ||
|
|
@@ -211,6 +247,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: | |
| return self.clamp(a + 1) | ||
| return a | ||
|
|
||
| def encode_dim(self) -> int: | ||
| return 1 | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| assert isinstance(value, int) | ||
| return [float(value)] | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class EnumFragment(ConfigSpecFragment): | ||
|
|
@@ -235,7 +278,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> object: | |
| choices.remove(a) | ||
| return random.choice(choices) | ||
|
|
||
| def encode_scalar(self, value: object) -> float: | ||
| def encode_dim(self) -> int: | ||
| return len(self.choices) | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| """Encode enum values as their index.""" | ||
| try: | ||
| choice_idx = self.choices.index(value) | ||
|
|
@@ -244,7 +290,7 @@ def encode_scalar(self, value: object) -> float: | |
| f"Invalid enum value {value!r} for EnumFragment. " | ||
| f"Valid choices: {self.choices}" | ||
| ) from None | ||
| return float(choice_idx) | ||
| return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))] | ||
|
|
||
|
|
||
| class BooleanFragment(ConfigSpecFragment): | ||
|
|
@@ -265,6 +311,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool: | |
| return a | ||
| return not a | ||
|
|
||
| def encode_dim(self) -> int: | ||
| return 1 | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| """Encode enum values as their index.""" | ||
| assert isinstance(value, bool) | ||
| return [1.0] if value else [0.0] | ||
|
|
||
|
|
||
| class BlockSizeFragment(PowerOfTwoFragment): | ||
| def category(self) -> Category: | ||
|
|
@@ -296,6 +350,9 @@ def random(self) -> list[object]: | |
| """Return a list of random values.""" | ||
| return [self.inner.random() for _ in range(self.length)] | ||
|
|
||
| def is_categorical(self) -> bool: | ||
| return self.inner.is_categorical() | ||
|
|
||
| def pattern_neighbors(self, current: object) -> list[object]: | ||
| """Return neighbors by changing one element at a time.""" | ||
| if not isinstance(current, list) or len(current) != self.length: | ||
|
|
@@ -320,3 +377,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object] | |
| self.inner.differential_mutation(a[i], b[i], c[i]) | ||
| for i in range(self.length) | ||
| ] | ||
|
|
||
| def encode_dim(self): | ||
| return self.length * self.inner.encode_dim() | ||
|
|
||
| def encode(self, value: object) -> list[float]: | ||
| assert isinstance(value, list) | ||
| encoded = [] | ||
| for v in value: | ||
| encoded.extend(self.inner.encode(v)) | ||
| return encoded | ||
Uh oh!
There was an error while loading. Please reload this page.