@@ -215,8 +215,7 @@ def get_nearest_neighbors(
215215
216216def get_categorical_neighbors (
217217 current_x : Tensor ,
218- bounds : Tensor ,
219- cat_dims : Tensor ,
218+ cat_dims : dict [int , list [float ]],
220219 max_num_cat_values : int = MAX_DISCRETE_VALUES ,
221220) -> Tensor :
222221 r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
@@ -231,8 +230,8 @@ def get_categorical_neighbors(
231230
232231 Args:
233232 current_x: The design to find the neighbors of. A tensor of shape `d`.
234- bounds : A `2 x d` tensor of lower and upper bounds for each column of `X`.
235- cat_dims: A tensor of indices corresponding to categorical parameters .
233+ cat_dims : A dictionary mapping indices of categorical dimensions
234+ to a list of allowed values for that dimension .
236235 max_num_cat_values: Maximum number of values for a categorical parameter,
237236 beyond which values are uniformly sampled.
238237
@@ -246,31 +245,31 @@ def get_categorical_neighbors(
246245 def _get_cat_values (dim : int ) -> Sequence [int ]:
247246 r"""Get a sequence of up to `max_num_cat_values` values that a categorical
248247 feature may take."""
249- lb , ub = bounds [:, dim ].long ()
250248 current_value = current_x [dim ]
251- cat_values = range (lb , ub + 1 )
252- if ub - lb + 1 <= max_num_cat_values :
253- return cat_values
249+ if len (cat_dims [dim ]) <= max_num_cat_values :
250+ return cat_dims [dim ]
254251 else :
255252 return random .sample (
256- [v for v in cat_values if v != current_value ], k = max_num_cat_values
253+ [v for v in cat_dims [ dim ] if v != current_value ], k = max_num_cat_values
257254 )
258255
256+ new_cat_values_dict = {dim : _get_cat_values (dim ) for dim in cat_dims .keys ()}
259257 new_cat_values_lst = list (
260- itertools .chain .from_iterable (_get_cat_values ( dim ) for dim in cat_dims )
258+ itertools .chain .from_iterable (new_cat_values_dict . values () )
261259 )
262260 new_cat_values = torch .tensor (
263261 new_cat_values_lst , device = current_x .device , dtype = current_x .dtype
264262 )
265263
266- num_cat_values = (bounds [1 , :] - bounds [0 , :] + 1 ).to (dtype = torch .long )
267- num_cat_values .clamp_ (max = max_num_cat_values )
268264 new_cat_idcs = torch .cat (
269265 tuple (
270- torch .full ((num_cat_values [dim ].item (),), dim , device = current_x .device )
271- for dim in cat_dims
266+ torch .full (
267+ (min (len (values ), max_num_cat_values ),), dim , device = current_x .device
268+ )
269+ for dim , values in new_cat_values_dict .items ()
272270 )
273271 )
272+
274273 neighbors = current_x .repeat (len (new_cat_values ), 1 )
275274 # Assign the new values to their corresponding columns.
276275 neighbors .scatter_ (1 , new_cat_idcs .view (- 1 , 1 ), new_cat_values .view (- 1 , 1 ))
@@ -285,7 +284,7 @@ def get_spray_points(
285284 X_baseline : Tensor ,
286285 cont_dims : Tensor ,
287286 discrete_dims : dict [int , list [float ]],
288- cat_dims : Tensor ,
287+ cat_dims : dict [ int , list [ float ]] ,
289288 bounds : Tensor ,
290289 num_spray_points : int ,
291290 std_cont_perturbation : float = STD_CONT_PERTURBATION ,
@@ -301,7 +300,8 @@ def get_spray_points(
301300 cont_dims: Indices of continuous parameters/input dimensions.
302301 discrete_dims: A dictionary mapping indices of discrete dimensions
303302 to a list of allowed values for that dimension.
304- cat_dims: Indices of categorical parameters/input dimensions.
303+ cat_dims: A dictionary mapping indices of categorical dimensions
304+ to a list of allowed values for that dimension.
305305 bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
306306 num_spray_points: Number of spray points to return.
307307 std_cont_perturbation: standard deviation of Normal perturbations of
@@ -316,6 +316,7 @@ def get_spray_points(
316316 t_discrete_dims = torch .tensor (
317317 list (discrete_dims .keys ()), dtype = torch .long , device = device
318318 )
319+ t_cat_dims = torch .tensor (list (cat_dims .keys ()), dtype = torch .long , device = device )
319320 for x in X_baseline :
320321 if len (discrete_dims ) > 0 :
321322 discrete_perturbs = get_nearest_neighbors (
@@ -326,10 +327,8 @@ def get_spray_points(
326327 len (discrete_perturbs ), (num_spray_points ,), device = device
327328 )
328329 ]
329- if cat_dims .numel ():
330- cat_perturbs = get_categorical_neighbors (
331- current_x = x , bounds = bounds , cat_dims = cat_dims
332- )
330+ if len (cat_dims ) > 0 :
331+ cat_perturbs = get_categorical_neighbors (current_x = x , cat_dims = cat_dims )
333332 cat_perturbs = cat_perturbs [
334333 torch .randint (len (cat_perturbs ), (num_spray_points ,), device = device )
335334 ]
@@ -343,8 +342,8 @@ def get_spray_points(
343342 nbds = torch .zeros (num_spray_points , dim , device = device , dtype = dtype )
344343 if len (discrete_dims ) > 0 :
345344 nbds [..., t_discrete_dims ] = discrete_perturbs [..., t_discrete_dims ]
346- if cat_dims . numel () :
347- nbds [..., cat_dims ] = cat_perturbs [..., cat_dims ]
345+ if len ( cat_dims ) > 0 :
346+ nbds [..., t_cat_dims ] = cat_perturbs [..., t_cat_dims ]
348347
349348 nbds [..., cont_dims ] = cont_perturbs
350349 perturb_nbors = torch .cat ([perturb_nbors , nbds ], dim = 0 )
@@ -354,7 +353,7 @@ def get_spray_points(
354353def sample_feasible_points (
355354 opt_inputs : OptimizeAcqfInputs ,
356355 discrete_dims : dict [int , list [float ]],
357- cat_dims : Tensor ,
356+ cat_dims : dict [ int , list [ float ]] ,
358357 num_points : int ,
359358) -> Tensor :
360359 r"""Sample feasible points from the optimization domain.
@@ -374,7 +373,8 @@ def sample_feasible_points(
374373 opt_inputs: Common set of arguments for acquisition optimization.
375374 discrete_dims: A dictionary mapping indices of discrete dimensions
376375 to a list of allowed values for that dimension.
377- cat_dims: A tensor of indices corresponding to categorical parameters.
376+ cat_dims: A dictionary mapping indices of categorical dimensions
377+ to a list of allowed values for that dimension.
378378 num_points: The number of points to sample.
379379
380380 Returns:
@@ -413,7 +413,7 @@ def generator(n: int) -> Tensor:
413413 base_points = generator (n = num_remaining * 2 )
414414 # Round the discrete dimensions to the nearest integer.
415415 base_points = round_discrete_dims (X = base_points , discrete_dims = discrete_dims )
416- base_points [:, cat_dims ] = base_points [:, cat_dims ]. round ( )
416+ base_points = round_discrete_dims ( X = base_points , discrete_dims = cat_dims )
417417 # Fix the fixed features.
418418 base_points = fix_features (
419419 X = base_points ,
@@ -457,7 +457,7 @@ def round_discrete_dims(X: Tensor, discrete_dims: dict[int, list[float]]) -> Ten
457457def generate_starting_points (
458458 opt_inputs : OptimizeAcqfInputs ,
459459 discrete_dims : dict [int , list [float ]],
460- cat_dims : Tensor ,
460+ cat_dims : dict [ int , list [ float ]] ,
461461 cont_dims : Tensor ,
462462) -> tuple [Tensor , Tensor ]:
463463 """Generate initial starting points for the alternating optimization.
@@ -472,7 +472,8 @@ def generate_starting_points(
472472 from `opt_inputs`.
473473 discrete_dims: A dictionary mapping indices of discrete dimensions
474474 to a list of allowed values for that dimension.
475- cat_dims: A tensor of indices corresponding to categorical parameters.
475+ cat_dims: A dictionary mapping indices of categorical dimensions
476+ to a list of allowed values for that dimension.
476477 cont_dims: A tensor of indices corresponding to continuous parameters.
477478
478479 Returns:
@@ -625,7 +626,7 @@ def generate_starting_points(
625626def discrete_step (
626627 opt_inputs : OptimizeAcqfInputs ,
627628 discrete_dims : dict [int , list [float ]],
628- cat_dims : Tensor ,
629+ cat_dims : dict [ int , list [ float ]] ,
629630 current_x : Tensor ,
630631) -> tuple [Tensor , Tensor ]:
631632 """Discrete nearest neighbour search.
@@ -636,7 +637,8 @@ def discrete_step(
636637 and constraints from `opt_inputs`.
637638 discrete_dims: A dictionary mapping indices of discrete dimensions
638639 to a list of allowed values for that dimension.
639- cat_dims: A tensor of indices corresponding to categorical parameters.
640+ cat_dims: A dictionary mapping indices of categorical dimensions
641+ to a list of allowed values for that dimension.
640642 current_x: Batch of starting points. A tensor of shape `b x d`.
641643
642644 Returns:
@@ -676,10 +678,9 @@ def discrete_step(
676678 neighbors .append (x_neighbors_discrete )
677679
678680 # if we have cat_dims look for neighbors by changing the cat's
679- if cat_dims . numel () :
681+ if len ( cat_dims ) > 0 :
680682 x_neighbors_cat = get_categorical_neighbors (
681683 current_x = current_x [i ].detach (),
682- bounds = opt_inputs .bounds ,
683684 cat_dims = cat_dims ,
684685 )
685686 x_neighbors_cat = _filter_infeasible (
@@ -807,7 +808,7 @@ def optimize_acqf_mixed_alternating(
807808 acq_function : AcquisitionFunction ,
808809 bounds : Tensor ,
809810 discrete_dims : dict [int , list [float ]] | None = None ,
810- cat_dims : list [int ] | None = None ,
811+ cat_dims : dict [int , list [ float ] ] | None = None ,
811812 options : dict [str , Any ] | None = None ,
812813 q : int = 1 ,
813814 raw_samples : int = RAW_SAMPLES ,
@@ -837,7 +838,8 @@ def optimize_acqf_mixed_alternating(
837838 bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
838839 discrete_dims: A dictionary mapping indices of discrete and binary
839840 dimensions to a list of allowed values for that dimension.
840- cat_dims: A list of indices corresponding to categorical parameters.
841+ cat_dims: A dictionary mapping indices of categorical dimensions
842+ to a list of allowed values for that dimension.
841843 options: Dictionary specifying optimization options. Supports the following:
842844 - "initialization_strategy": Strategy used to generate the initial candidates.
843845 "random", "continuous_relaxation" or "equally_spaced" (linspace style).
@@ -891,12 +893,15 @@ def optimize_acqf_mixed_alternating(
891893 "sequential optimization."
892894 )
893895
894- cat_dims = cat_dims or []
896+ cat_dims = cat_dims or {}
895897 discrete_dims = discrete_dims or {}
896898
897899 # sort the values in discrete dims in ascending order
898900 discrete_dims = {dim : sorted (values ) for dim , values in discrete_dims .items ()}
899901
902+ # sort the categorical dims in ascending order
903+ cat_dims = {dim : sorted (values ) for dim , values in cat_dims .items ()}
904+
900905 for dim , values in discrete_dims .items ():
901906 lower_bnd , upper_bnd = bounds [:, dim ].tolist ()
902907 lower , upper = values [0 ], values [- 1 ]
@@ -972,8 +977,10 @@ def optimize_acqf_mixed_alternating(
972977 for dim , values in discrete_dims .items ()
973978 if dim not in fixed_features
974979 }
975- cat_dims = [dim for dim in cat_dims if dim not in fixed_features ]
976- non_cont_dims = [* discrete_dims .keys (), * cat_dims ]
980+ cat_dims = {
981+ dim : values for dim , values in cat_dims .items () if dim not in fixed_features
982+ }
983+ non_cont_dims = [* discrete_dims .keys (), * cat_dims .keys ()]
977984 if len (non_cont_dims ) == 0 :
978985 # If the problem is fully continuous, fall back to standard optimization.
979986 return _optimize_acqf (
@@ -989,13 +996,15 @@ def optimize_acqf_mixed_alternating(
989996 and max (non_cont_dims ) <= dim - 1
990997 ):
991998 raise ValueError (
992- "`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
993- "integers between 0 and num_dims - 1."
999+ "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
1000+ "integers as keys between 0 and num_dims - 1."
9941001 )
9951002 discrete_dims_t = torch .tensor (
9961003 list (discrete_dims .keys ()), dtype = torch .long , device = tkwargs ["device" ]
9971004 )
998- cat_dims_t = torch .tensor (cat_dims , dtype = torch .long , device = tkwargs ["device" ])
1005+ cat_dims_t = torch .tensor (
1006+ list (cat_dims .keys ()), dtype = torch .long , device = tkwargs ["device" ]
1007+ )
9991008 non_cont_dims = torch .tensor (
10001009 non_cont_dims , dtype = torch .long , device = tkwargs ["device" ]
10011010 )
@@ -1011,7 +1020,7 @@ def optimize_acqf_mixed_alternating(
10111020 best_X , best_acq_val = generate_starting_points (
10121021 opt_inputs = opt_inputs ,
10131022 discrete_dims = discrete_dims ,
1014- cat_dims = cat_dims_t ,
1023+ cat_dims = cat_dims ,
10151024 cont_dims = cont_dims ,
10161025 )
10171026
@@ -1021,7 +1030,7 @@ def optimize_acqf_mixed_alternating(
10211030 best_X [~ done ], best_acq_val [~ done ] = discrete_step (
10221031 opt_inputs = opt_inputs ,
10231032 discrete_dims = discrete_dims ,
1024- cat_dims = cat_dims_t ,
1033+ cat_dims = cat_dims ,
10251034 current_x = best_X [~ done ],
10261035 )
10271036
0 commit comments