@@ -14,6 +14,8 @@
 import math
 
 from abc import ABC
+
+from contextlib import nullcontext
 from copy import deepcopy
 from typing import Dict, Optional, Tuple, Union
 
@@ -441,6 +443,91 @@ def _compute_prob_feas(self, X: Tensor, means: Tensor, sigmas: Tensor) -> Tensor
         return prob_feas


+class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):
+    r"""Single-outcome Log Noisy Expected Improvement (via fantasies).
+
+    This computes Log Noisy Expected Improvement by averaging over the Expected
+    Improvement values of a number of fantasy models. Only supports the case
+    `q=1`. Assumes that the posterior distribution of the model is Gaussian.
+    The model must be single-outcome.
+
+    `LogNEI(x) = log(E(max(y - max Y_base, 0))), (y, Y_base) ~ f((x, X_base))`,
+    where `X_base` are previously observed points.
+
+    Note: This acquisition function currently relies on using a FixedNoiseGP
+    (required for noiseless fantasies).
+
+    Example:
+        >>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)
+        >>> LogNEI = LogNoisyExpectedImprovement(model, train_X)
+        >>> nei = LogNEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: GPyTorchModel,
+        X_observed: Tensor,
+        num_fantasies: int = 20,
+        maximize: bool = True,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        **kwargs,
+    ) -> None:
+        r"""Single-outcome Log Noisy Expected Improvement (via fantasies).
+
+        Args:
+            model: A fitted single-outcome model.
+            X_observed: A `n x d` Tensor of observed points that are likely to
+                be the best observed points so far.
+            num_fantasies: The number of fantasies to generate. The higher this
+                number the more accurate the model (at the expense of model
+                complexity and performance).
+            maximize: If True, consider the problem a maximization problem.
+            posterior_transform: An optional PosteriorTransform.
+        """
+        if not isinstance(model, FixedNoiseGP):
+            raise UnsupportedError(
+                "Only FixedNoiseGPs are currently supported for fantasy LogNEI"
+            )
+        # sample fantasies
+        from botorch.sampling.normal import SobolQMCNormalSampler
+
+        # Drop gradients from model.posterior if X_observed does not require gradients
+        # as otherwise, gradients of the GP's kernel hyperparameters are tracked
+        # through the rsample_from_base_sample method of GPyTorchPosterior. These
+        # gradients are usually only required w.r.t. the marginal likelihood.
+        with nullcontext() if X_observed.requires_grad else torch.no_grad():
+            posterior = model.posterior(X=X_observed)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
+            Y_fantasized = sampler(posterior).squeeze(-1)
+        batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
+        # The fantasy model will operate in batch mode
+        fantasy_model = _get_noiseless_fantasy_model(
+            model=model, batch_X_observed=batch_X_observed, Y_fantasized=Y_fantasized
+        )
+        super().__init__(
+            model=fantasy_model, posterior_transform=posterior_transform, **kwargs
+        )
+        best_f, _ = Y_fantasized.max(dim=-1) if maximize else Y_fantasized.min(dim=-1)
+        self.best_f, self.maximize = best_f, maximize
+
+    @t_batch_mode_transform(expected_q=1)
+    def forward(self, X: Tensor) -> Tensor:
+        r"""Evaluate logarithm of the mean Expected Improvement on the candidate set X.
+
+        Args:
+            X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
+
+        Returns:
+            A `b1 x ... bk`-dim tensor of Log Noisy Expected Improvement values at
+            the given design points `X`.
+        """
+        # add batch dimension for broadcasting to fantasy models
+        mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
+        u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
+        log_ei = _log_ei_helper(u) + sigma.log()
+        # this is mathematically - though not numerically - equivalent to log(mean(ei))
+        return torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])
+
+
 class NoisyExpectedImprovement(ExpectedImprovement):
     r"""Single-outcome Noisy Expected Improvement (via fantasies).
 
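A note on the numerics of the new `forward` above: per fantasy, `_log_ei_helper(u) + sigma.log()` computes `log EI = log σ + log h(u)` with `u = (μ - best_f) / σ` and `h(u) = φ(u) + u Φ(u)`, and the final `logsumexp(log_ei, dim=-1) - log(n)` is, as the inline comment says, mathematically equal to `log(mean(ei))` but stays finite when every per-fantasy EI underflows to zero. A minimal standalone sketch (the values are made up for illustration):

```python
import math

import torch

# Hypothetical per-fantasy log-EI values, small enough that exp() underflows.
log_ei = torch.tensor([-800.0, -805.0, -810.0])

# Naive route: exponentiate, average, take the log. Underflows to -inf.
naive = torch.log(log_ei.exp().mean())

# Stable route used by LogNoisyExpectedImprovement.forward.
stable = torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])

print(naive)   # tensor(-inf)
print(stable)  # tensor(-801.0919): finite, as expected
```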
@@ -486,10 +573,14 @@ def __init__(
         # sample fantasies
         from botorch.sampling.normal import SobolQMCNormalSampler
 
-        with torch.no_grad():
+        # Drop gradients from model.posterior if X_observed does not require gradients
+        # as otherwise, gradients of the GP's kernel hyperparameters are tracked
+        # through the rsample_from_base_sample method of GPyTorchPosterior. These
+        # gradients are usually only required w.r.t. the marginal likelihood.
+        with nullcontext() if X_observed.requires_grad else torch.no_grad():
             posterior = model.posterior(X=X_observed)
-        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
-        Y_fantasized = sampler(posterior).squeeze(-1)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
+            Y_fantasized = sampler(posterior).squeeze(-1)
         batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
         # The fantasy model will operate in batch mode
         fantasy_model = _get_noiseless_fantasy_model(
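The conditional gradient context introduced in both constructors is worth a standalone illustration. Below is a minimal sketch (toy function, not from this patch) of the same pattern: `nullcontext()` is a no-op, so autograd runs normally exactly when the input itself requires gradients, while `torch.no_grad()` keeps unrelated parameters, here a stand-in for the kernel hyperparameters, out of the graph:

```python
from contextlib import nullcontext

import torch

def posterior_like(x: torch.Tensor, hyper: torch.Tensor) -> torch.Tensor:
    # Mirrors the constructors: track gradients only if x requires them,
    # so that `hyper` (a stand-in for the GP's kernel hyperparameters) is
    # not pulled into the autograd graph during fantasy sampling.
    with nullcontext() if x.requires_grad else torch.no_grad():
        return (hyper * x).sum()

hyper = torch.ones(1, requires_grad=True)
assert posterior_like(torch.randn(3, requires_grad=True), hyper).requires_grad
assert not posterior_like(torch.randn(3), hyper).requires_grad
```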
@@ -515,45 +606,6 @@ def forward(self, X: Tensor) -> Tensor:
         return (sigma * _ei_helper(u)).mean(dim=-1)
 
 
-def _get_noiseless_fantasy_model(
-    model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
-) -> FixedNoiseGP:
-    r"""Construct a fantasy model from a fitted model and provided fantasies.
-
-    The fantasy model uses the hyperparameters from the original fitted model and
-    assumes the fantasies are noiseless.
-
-    Args:
-        model: a fitted FixedNoiseGP
-        batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of
-            fantasies.
-        Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of
-            fantasies.
-
-    Returns:
-        The fantasy model.
-    """
-    # initialize a copy of FixedNoiseGP on the original training inputs
-    # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
-    # are used across all batches (by default, a GP with batched training data
-    # uses independent hyperparameters for each batch).
-    fantasy_model = FixedNoiseGP(
-        train_X=model.train_inputs[0],
-        train_Y=model.train_targets.unsqueeze(-1),
-        train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
-    )
-    # update training inputs/targets to be batch mode fantasies
-    fantasy_model.set_train_data(
-        inputs=batch_X_observed, targets=Y_fantasized, strict=False
-    )
-    # use noiseless fantasies
-    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
-    # load hyperparameters from original model
-    state_dict = deepcopy(model.state_dict())
-    fantasy_model.load_state_dict(state_dict)
-    return fantasy_model
-
-
 class UpperConfidenceBound(AnalyticAcquisitionFunction):
     r"""Single-outcome Upper Confidence Bound (UCB).
 
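The hunk above only relocates `_get_noiseless_fantasy_model`; it is re-added verbatim at the end of the file in the next hunk. Its comment about non-batch construction deserves unpacking: a GP built directly on batched training data gets independent hyperparameters per batch, whereas building on the original non-batch data and then swapping in the batched fantasies via `set_train_data` keeps a single shared set. A small sketch of that behavior (shapes are illustrative assumptions):

```python
import torch
from botorch.models import FixedNoiseGP

n, d, b = 10, 2, 20
X = torch.rand(n, d, dtype=torch.double)
Y = torch.rand(n, 1, dtype=torch.double)
Yvar = torch.full_like(Y, 1e-4)

# Built on batched data: one lengthscale vector per batch.
batched = FixedNoiseGP(X.expand(b, n, d), Y.expand(b, n, 1), Yvar.expand(b, n, 1))
# Built on non-batch data: a single shared lengthscale vector.
shared = FixedNoiseGP(X, Y, Yvar)

print(batched.covar_module.base_kernel.lengthscale.shape)  # torch.Size([20, 1, 2])
print(shared.covar_module.base_kernel.lengthscale.shape)   # torch.Size([1, 2])
```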
@@ -807,3 +859,42 @@ def _construct_dist(means: Tensor, sigmas: Tensor, inds: Tensor) -> Normal:
     mean = means.index_select(dim=-1, index=inds)
     sigma = sigmas.index_select(dim=-1, index=inds)
     return Normal(loc=mean, scale=sigma)
+
+
+def _get_noiseless_fantasy_model(
+    model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
+) -> FixedNoiseGP:
+    r"""Construct a fantasy model from a fitted model and provided fantasies.
+
+    The fantasy model uses the hyperparameters from the original fitted model and
+    assumes the fantasies are noiseless.
+
+    Args:
+        model: a fitted FixedNoiseGP
+        batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of
+            fantasies.
+        Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of
+            fantasies.
+
+    Returns:
+        The fantasy model.
+    """
+    # initialize a copy of FixedNoiseGP on the original training inputs
+    # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
+    # are used across all batches (by default, a GP with batched training data
+    # uses independent hyperparameters for each batch).
+    fantasy_model = FixedNoiseGP(
+        train_X=model.train_inputs[0],
+        train_Y=model.train_targets.unsqueeze(-1),
+        train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
+    )
+    # update training inputs/targets to be batch mode fantasies
+    fantasy_model.set_train_data(
+        inputs=batch_X_observed, targets=Y_fantasized, strict=False
+    )
+    # use noiseless fantasies
+    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
+    # load hyperparameters from original model
+    state_dict = deepcopy(model.state_dict())
+    fantasy_model.load_state_dict(state_dict)
+    return fantasy_model
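Finally, a hedged end-to-end usage sketch of the new acquisition function; the shapes, noise level, and toy objective are illustrative assumptions, not taken from this patch. It fits a `FixedNoiseGP`, constructs `LogNoisyExpectedImprovement` (which calls `_get_noiseless_fantasy_model` internally), and evaluates a batch of `q=1` candidates:

```python
import torch
from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import FixedNoiseGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()  # toy objective
train_Yvar = torch.full_like(train_Y, 1e-4)        # known observation noise

model = FixedNoiseGP(train_X, train_Y, train_Yvar)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

log_nei = LogNoisyExpectedImprovement(model, X_observed=train_X, num_fantasies=20)
test_X = torch.rand(5, 1, 2, dtype=torch.double)   # `b x q x d` with q=1
print(log_nei(test_X))                             # 5-dim tensor of log-NEI values
```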