Skip to content

Commit 2d5deba

Browse files
jduerholtfacebook-github-bot
authored andcommitted
Update syntax for categoricals in optimize_acqf_mixed_alternating (#2942)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation Following #2923, this PR updates the syntax for the categorical dimensions in order to match the syntax for the discrete ones. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #2942 Test Plan: Unit tests. Reviewed By: saitcakmak Differential Revision: D79091028 Pulled By: esantorella fbshipit-source-id: 6324bfb2e78eba13d912c243e0a405101e33c2dc
1 parent 0cb012f commit 2d5deba

File tree

2 files changed

+103
-65
lines changed

2 files changed

+103
-65
lines changed

botorch/optim/optimize_mixed.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ def get_nearest_neighbors(
215215

216216
def get_categorical_neighbors(
217217
current_x: Tensor,
218-
bounds: Tensor,
219-
cat_dims: Tensor,
218+
cat_dims: dict[int, list[float]],
220219
max_num_cat_values: int = MAX_DISCRETE_VALUES,
221220
) -> Tensor:
222221
r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
@@ -231,8 +230,8 @@ def get_categorical_neighbors(
231230
232231
Args:
233232
current_x: The design to find the neighbors of. A tensor of shape `d`.
234-
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
235-
cat_dims: A tensor of indices corresponding to categorical parameters.
233+
cat_dims: A dictionary mapping indices of categorical dimensions
234+
to a list of allowed values for that dimension.
236235
max_num_cat_values: Maximum number of values for a categorical parameter,
237236
beyond which values are uniformly sampled.
238237
@@ -246,31 +245,31 @@ def get_categorical_neighbors(
246245
def _get_cat_values(dim: int) -> Sequence[int]:
247246
r"""Get a sequence of up to `max_num_cat_values` values that a categorical
248247
feature may take."""
249-
lb, ub = bounds[:, dim].long()
250248
current_value = current_x[dim]
251-
cat_values = range(lb, ub + 1)
252-
if ub - lb + 1 <= max_num_cat_values:
253-
return cat_values
249+
if len(cat_dims[dim]) <= max_num_cat_values:
250+
return cat_dims[dim]
254251
else:
255252
return random.sample(
256-
[v for v in cat_values if v != current_value], k=max_num_cat_values
253+
[v for v in cat_dims[dim] if v != current_value], k=max_num_cat_values
257254
)
258255

256+
new_cat_values_dict = {dim: _get_cat_values(dim) for dim in cat_dims.keys()}
259257
new_cat_values_lst = list(
260-
itertools.chain.from_iterable(_get_cat_values(dim) for dim in cat_dims)
258+
itertools.chain.from_iterable(new_cat_values_dict.values())
261259
)
262260
new_cat_values = torch.tensor(
263261
new_cat_values_lst, device=current_x.device, dtype=current_x.dtype
264262
)
265263

266-
num_cat_values = (bounds[1, :] - bounds[0, :] + 1).to(dtype=torch.long)
267-
num_cat_values.clamp_(max=max_num_cat_values)
268264
new_cat_idcs = torch.cat(
269265
tuple(
270-
torch.full((num_cat_values[dim].item(),), dim, device=current_x.device)
271-
for dim in cat_dims
266+
torch.full(
267+
(min(len(values), max_num_cat_values),), dim, device=current_x.device
268+
)
269+
for dim, values in new_cat_values_dict.items()
272270
)
273271
)
272+
274273
neighbors = current_x.repeat(len(new_cat_values), 1)
275274
# Assign the new values to their corresponding columns.
276275
neighbors.scatter_(1, new_cat_idcs.view(-1, 1), new_cat_values.view(-1, 1))
@@ -285,7 +284,7 @@ def get_spray_points(
285284
X_baseline: Tensor,
286285
cont_dims: Tensor,
287286
discrete_dims: dict[int, list[float]],
288-
cat_dims: Tensor,
287+
cat_dims: dict[int, list[float]],
289288
bounds: Tensor,
290289
num_spray_points: int,
291290
std_cont_perturbation: float = STD_CONT_PERTURBATION,
@@ -301,7 +300,8 @@ def get_spray_points(
301300
cont_dims: Indices of continuous parameters/input dimensions.
302301
discrete_dims: A dictionary mapping indices of discrete dimensions
303302
to a list of allowed values for that dimension.
304-
cat_dims: Indices of categorical parameters/input dimensions.
303+
cat_dims: A dictionary mapping indices of categorical dimensions
304+
to a list of allowed values for that dimension.
305305
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
306306
num_spray_points: Number of spray points to return.
307307
std_cont_perturbation: standard deviation of Normal perturbations of
@@ -316,6 +316,7 @@ def get_spray_points(
316316
t_discrete_dims = torch.tensor(
317317
list(discrete_dims.keys()), dtype=torch.long, device=device
318318
)
319+
t_cat_dims = torch.tensor(list(cat_dims.keys()), dtype=torch.long, device=device)
319320
for x in X_baseline:
320321
if len(discrete_dims) > 0:
321322
discrete_perturbs = get_nearest_neighbors(
@@ -326,10 +327,8 @@ def get_spray_points(
326327
len(discrete_perturbs), (num_spray_points,), device=device
327328
)
328329
]
329-
if cat_dims.numel():
330-
cat_perturbs = get_categorical_neighbors(
331-
current_x=x, bounds=bounds, cat_dims=cat_dims
332-
)
330+
if len(cat_dims) > 0:
331+
cat_perturbs = get_categorical_neighbors(current_x=x, cat_dims=cat_dims)
333332
cat_perturbs = cat_perturbs[
334333
torch.randint(len(cat_perturbs), (num_spray_points,), device=device)
335334
]
@@ -343,8 +342,8 @@ def get_spray_points(
343342
nbds = torch.zeros(num_spray_points, dim, device=device, dtype=dtype)
344343
if len(discrete_dims) > 0:
345344
nbds[..., t_discrete_dims] = discrete_perturbs[..., t_discrete_dims]
346-
if cat_dims.numel():
347-
nbds[..., cat_dims] = cat_perturbs[..., cat_dims]
345+
if len(cat_dims) > 0:
346+
nbds[..., t_cat_dims] = cat_perturbs[..., t_cat_dims]
348347

349348
nbds[..., cont_dims] = cont_perturbs
350349
perturb_nbors = torch.cat([perturb_nbors, nbds], dim=0)
@@ -354,7 +353,7 @@ def get_spray_points(
354353
def sample_feasible_points(
355354
opt_inputs: OptimizeAcqfInputs,
356355
discrete_dims: dict[int, list[float]],
357-
cat_dims: Tensor,
356+
cat_dims: dict[int, list[float]],
358357
num_points: int,
359358
) -> Tensor:
360359
r"""Sample feasible points from the optimization domain.
@@ -374,7 +373,8 @@ def sample_feasible_points(
374373
opt_inputs: Common set of arguments for acquisition optimization.
375374
discrete_dims: A dictionary mapping indices of discrete dimensions
376375
to a list of allowed values for that dimension.
377-
cat_dims: A tensor of indices corresponding to categorical parameters.
376+
cat_dims: A dictionary mapping indices of categorical dimensions
377+
to a list of allowed values for that dimension.
378378
num_points: The number of points to sample.
379379
380380
Returns:
@@ -413,7 +413,7 @@ def generator(n: int) -> Tensor:
413413
base_points = generator(n=num_remaining * 2)
414414
# Round the discrete dimensions to the nearest integer.
415415
base_points = round_discrete_dims(X=base_points, discrete_dims=discrete_dims)
416-
base_points[:, cat_dims] = base_points[:, cat_dims].round()
416+
base_points = round_discrete_dims(X=base_points, discrete_dims=cat_dims)
417417
# Fix the fixed features.
418418
base_points = fix_features(
419419
X=base_points,
@@ -457,7 +457,7 @@ def round_discrete_dims(X: Tensor, discrete_dims: dict[int, list[float]]) -> Ten
457457
def generate_starting_points(
458458
opt_inputs: OptimizeAcqfInputs,
459459
discrete_dims: dict[int, list[float]],
460-
cat_dims: Tensor,
460+
cat_dims: dict[int, list[float]],
461461
cont_dims: Tensor,
462462
) -> tuple[Tensor, Tensor]:
463463
"""Generate initial starting points for the alternating optimization.
@@ -472,7 +472,8 @@ def generate_starting_points(
472472
from `opt_inputs`.
473473
discrete_dims: A dictionary mapping indices of discrete dimensions
474474
to a list of allowed values for that dimension.
475-
cat_dims: A tensor of indices corresponding to categorical parameters.
475+
cat_dims: A dictionary mapping indices of categorical dimensions
476+
to a list of allowed values for that dimension.
476477
cont_dims: A tensor of indices corresponding to continuous parameters.
477478
478479
Returns:
@@ -625,7 +626,7 @@ def generate_starting_points(
625626
def discrete_step(
626627
opt_inputs: OptimizeAcqfInputs,
627628
discrete_dims: dict[int, list[float]],
628-
cat_dims: Tensor,
629+
cat_dims: dict[int, list[float]],
629630
current_x: Tensor,
630631
) -> tuple[Tensor, Tensor]:
631632
"""Discrete nearest neighbour search.
@@ -636,7 +637,8 @@ def discrete_step(
636637
and constraints from `opt_inputs`.
637638
discrete_dims: A dictionary mapping indices of discrete dimensions
638639
to a list of allowed values for that dimension.
639-
cat_dims: A tensor of indices corresponding to categorical parameters.
640+
cat_dims: A dictionary mapping indices of categorical dimensions
641+
to a list of allowed values for that dimension.
640642
current_x: Batch of starting points. A tensor of shape `b x d`.
641643
642644
Returns:
@@ -676,10 +678,9 @@ def discrete_step(
676678
neighbors.append(x_neighbors_discrete)
677679

678680
# if we have cat_dims look for neighbors by changing the cat's
679-
if cat_dims.numel():
681+
if len(cat_dims) > 0:
680682
x_neighbors_cat = get_categorical_neighbors(
681683
current_x=current_x[i].detach(),
682-
bounds=opt_inputs.bounds,
683684
cat_dims=cat_dims,
684685
)
685686
x_neighbors_cat = _filter_infeasible(
@@ -807,7 +808,7 @@ def optimize_acqf_mixed_alternating(
807808
acq_function: AcquisitionFunction,
808809
bounds: Tensor,
809810
discrete_dims: dict[int, list[float]] | None = None,
810-
cat_dims: list[int] | None = None,
811+
cat_dims: dict[int, list[float]] | None = None,
811812
options: dict[str, Any] | None = None,
812813
q: int = 1,
813814
raw_samples: int = RAW_SAMPLES,
@@ -837,7 +838,8 @@ def optimize_acqf_mixed_alternating(
837838
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
838839
discrete_dims: A dictionary mapping indices of discrete and binary
839840
dimensions to a list of allowed values for that dimension.
840-
cat_dims: A list of indices corresponding to categorical parameters.
841+
cat_dims: A dictionary mapping indices of categorical dimensions
842+
to a list of allowed values for that dimension.
841843
options: Dictionary specifying optimization options. Supports the following:
842844
- "initialization_strategy": Strategy used to generate the initial candidates.
843845
"random", "continuous_relaxation" or "equally_spaced" (linspace style).
@@ -891,12 +893,15 @@ def optimize_acqf_mixed_alternating(
891893
"sequential optimization."
892894
)
893895

894-
cat_dims = cat_dims or []
896+
cat_dims = cat_dims or {}
895897
discrete_dims = discrete_dims or {}
896898

897899
# sort the values in discrete dims in ascending order
898900
discrete_dims = {dim: sorted(values) for dim, values in discrete_dims.items()}
899901

902+
# sort the categorical dims in ascending order
903+
cat_dims = {dim: sorted(values) for dim, values in cat_dims.items()}
904+
900905
for dim, values in discrete_dims.items():
901906
lower_bnd, upper_bnd = bounds[:, dim].tolist()
902907
lower, upper = values[0], values[-1]
@@ -972,8 +977,10 @@ def optimize_acqf_mixed_alternating(
972977
for dim, values in discrete_dims.items()
973978
if dim not in fixed_features
974979
}
975-
cat_dims = [dim for dim in cat_dims if dim not in fixed_features]
976-
non_cont_dims = [*discrete_dims.keys(), *cat_dims]
980+
cat_dims = {
981+
dim: values for dim, values in cat_dims.items() if dim not in fixed_features
982+
}
983+
non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
977984
if len(non_cont_dims) == 0:
978985
# If the problem is fully continuous, fall back to standard optimization.
979986
return _optimize_acqf(
@@ -989,13 +996,15 @@ def optimize_acqf_mixed_alternating(
989996
and max(non_cont_dims) <= dim - 1
990997
):
991998
raise ValueError(
992-
"`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
993-
"integers between 0 and num_dims - 1."
999+
"`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
1000+
"integers as keys between 0 and num_dims - 1."
9941001
)
9951002
discrete_dims_t = torch.tensor(
9961003
list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
9971004
)
998-
cat_dims_t = torch.tensor(cat_dims, dtype=torch.long, device=tkwargs["device"])
1005+
cat_dims_t = torch.tensor(
1006+
list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
1007+
)
9991008
non_cont_dims = torch.tensor(
10001009
non_cont_dims, dtype=torch.long, device=tkwargs["device"]
10011010
)
@@ -1011,7 +1020,7 @@ def optimize_acqf_mixed_alternating(
10111020
best_X, best_acq_val = generate_starting_points(
10121021
opt_inputs=opt_inputs,
10131022
discrete_dims=discrete_dims,
1014-
cat_dims=cat_dims_t,
1023+
cat_dims=cat_dims,
10151024
cont_dims=cont_dims,
10161025
)
10171026

@@ -1021,7 +1030,7 @@ def optimize_acqf_mixed_alternating(
10211030
best_X[~done], best_acq_val[~done] = discrete_step(
10221031
opt_inputs=opt_inputs,
10231032
discrete_dims=discrete_dims,
1024-
cat_dims=cat_dims_t,
1033+
cat_dims=cat_dims,
10251034
current_x=best_X[~done],
10261035
)
10271036

0 commit comments

Comments
 (0)