
Commit 8f58784

SebastianAment authored and facebook-github-bot committed
Introducing LogNoisyExpectedImprovement (#1577)
Summary:
Pull Request resolved: #1577

Follow up on D41890063 (d819b2d), continuing the logarithmification of improvement-based acquisition functions. Notably, the existing test case for `NEI` already exhibits acquisition values that are exactly zero and gradients on the order of machine epsilon. This is solved by `LogNEI`.

Reviewed By: Balandat

Differential Revision: D42109272

fbshipit-source-id: 96ca29dfeba3687b708beff2a21fb3da4ae98d2b
1 parent 37e5061 commit 8f58784
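For context, here is a minimal end-to-end sketch of how the new acquisition function is meant to be used. The toy data, fitting call, and optimizer settings are illustrative assumptions following standard BoTorch usage (`fit_gpytorch_mll`, `optimize_acqf`); they are not part of this commit:

import torch
from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import FixedNoiseGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# toy training data; shapes and noise level are illustrative only
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = (6.28 * train_X).sin().sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 1e-4)

# LogNEI requires a FixedNoiseGP so that fantasies can be treated as noiseless
model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

log_nei = LogNoisyExpectedImprovement(model, X_observed=train_X, num_fantasies=20)

candidate, acq_value = optimize_acqf(
    acq_function=log_nei,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,  # LogNEI only supports q=1
    num_restarts=10,
    raw_samples=128,
)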

File tree

3 files changed: +184 −54 lines


botorch/acquisition/analytic.py

Lines changed: 133 additions & 42 deletions
@@ -14,6 +14,8 @@
 import math

 from abc import ABC
+
+from contextlib import nullcontext
 from copy import deepcopy
 from typing import Dict, Optional, Tuple, Union

@@ -441,6 +443,91 @@ def _compute_prob_feas(self, X: Tensor, means: Tensor, sigmas: Tensor) -> Tensor
         return prob_feas


+class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):
+    r"""Single-outcome Log Noisy Expected Improvement (via fantasies).
+
+    This computes Log Noisy Expected Improvement by averaging over the Expected
+    Improvement values of a number of fantasy models. Only supports the case
+    `q=1`. Assumes that the posterior distribution of the model is Gaussian.
+    The model must be single-outcome.
+
+    `LogNEI(x) = log(E(max(y - max Y_base, 0))), (y, Y_base) ~ f((x, X_base))`,
+    where `X_base` are previously observed points.
+
+    Note: This acquisition function currently relies on using a FixedNoiseGP (required
+    for noiseless fantasies).
+
+    Example:
+        >>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)
+        >>> LogNEI = LogNoisyExpectedImprovement(model, train_X)
+        >>> nei = LogNEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: GPyTorchModel,
+        X_observed: Tensor,
+        num_fantasies: int = 20,
+        maximize: bool = True,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        **kwargs,
+    ) -> None:
+        r"""Single-outcome Noisy Log Expected Improvement (via fantasies).
+
+        Args:
+            model: A fitted single-outcome model.
+            X_observed: A `n x d` Tensor of observed points that are likely to
+                be the best observed points so far.
+            num_fantasies: The number of fantasies to generate. The higher this
+                number the more accurate the model (at the expense of model
+                complexity and performance).
+            maximize: If True, consider the problem a maximization problem.
+        """
+        if not isinstance(model, FixedNoiseGP):
+            raise UnsupportedError(
+                "Only FixedNoiseGPs are currently supported for fantasy LogNEI"
+            )
+        # sample fantasies
+        from botorch.sampling.normal import SobolQMCNormalSampler
+
+        # Drop gradients from model.posterior if X_observed does not require gradients
+        # as otherwise, gradients of the GP's kernel's hyper-parameters are tracked
+        # through the rsample_from_base_sample method of GPyTorchPosterior. These
+        # gradients are usually only required w.r.t. the marginal likelihood.
+        with nullcontext() if X_observed.requires_grad else torch.no_grad():
+            posterior = model.posterior(X=X_observed)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
+            Y_fantasized = sampler(posterior).squeeze(-1)
+        batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
+        # The fantasy model will operate in batch mode
+        fantasy_model = _get_noiseless_fantasy_model(
+            model=model, batch_X_observed=batch_X_observed, Y_fantasized=Y_fantasized
+        )
+        super().__init__(
+            model=fantasy_model, posterior_transform=posterior_transform, **kwargs
+        )
+        best_f, _ = Y_fantasized.max(dim=-1) if maximize else Y_fantasized.min(dim=-1)
+        self.best_f, self.maximize = best_f, maximize
+
+    @t_batch_mode_transform(expected_q=1)
+    def forward(self, X: Tensor) -> Tensor:
+        r"""Evaluate logarithm of the mean Expected Improvement on the candidate set X.
+
+        Args:
+            X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
+
+        Returns:
+            A `b1 x ... bk`-dim tensor of Log Noisy Expected Improvement values at
+            the given design points `X`.
+        """
+        # add batch dimension for broadcasting to fantasy models
+        mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
+        u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
+        log_ei = _log_ei_helper(u) + sigma.log()
+        # this is mathematically - though not numerically - equivalent to log(mean(ei))
+        return torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])
+
+
 class NoisyExpectedImprovement(ExpectedImprovement):
     r"""Single-outcome Noisy Expected Improvement (via fantasies).
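The last line of `forward` above performs the averaging in log space: `logsumexp(log_ei, dim=-1) - log(n)` equals `log(mean(exp(log_ei)))` exactly in real arithmetic, but never materializes the underflow-prone `exp` values, so the result stays finite and differentiable. A small self-contained sketch of that identity (the values are illustrative only):

import math
import torch

# per-fantasy log-EI values; exp(-800) underflows to 0.0 in both float32 and float64
log_ei = torch.tensor([-800.0, -805.0, -810.0], dtype=torch.double)

# stable log-mean-exp, as returned by LogNEI's forward
log_mean_ei = torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])
print(log_mean_ei.item())  # ~ -801.09, still finite

# the naive route forms exp() first and collapses to log(0) = -inf
print(log_ei.exp().mean().log().item())  # -inf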
@@ -486,10 +573,14 @@ def __init__(
         # sample fantasies
         from botorch.sampling.normal import SobolQMCNormalSampler

-        with torch.no_grad():
+        # Drop gradients from model.posterior if X_observed does not require gradients
+        # as otherwise, gradients of the GP's kernel's hyper-parameters are tracked
+        # through the rsample_from_base_sample method of GPyTorchPosterior. These
+        # gradients are usually only required w.r.t. the marginal likelihood.
+        with nullcontext() if X_observed.requires_grad else torch.no_grad():
             posterior = model.posterior(X=X_observed)
-        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
-        Y_fantasized = sampler(posterior).squeeze(-1)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
+            Y_fantasized = sampler(posterior).squeeze(-1)
         batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
         # The fantasy model will operate in batch mode
         fantasy_model = _get_noiseless_fantasy_model(
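Both constructors now pick the gradient context conditionally: `nullcontext()` is a no-op, so gradients flow through the fantasy sampling when `X_observed` itself requires them, and `torch.no_grad()` is used otherwise so kernel hyperparameters do not pick up spurious gradients through the fantasy targets. A stripped-down sketch of the pattern with a toy function (not BoTorch code):

from contextlib import nullcontext

import torch


def sample_values(x: torch.Tensor) -> torch.Tensor:
    # Track gradients only if the caller passed a tensor that requires them;
    # otherwise run under no_grad so nothing is recorded on the autograd graph.
    with nullcontext() if x.requires_grad else torch.no_grad():
        return (2.0 * x).sum()


x_plain = torch.ones(3)                      # requires_grad=False
x_diff = torch.ones(3, requires_grad=True)   # requires_grad=True

print(sample_values(x_plain).requires_grad)  # False: computed under no_grad
print(sample_values(x_diff).requires_grad)   # True: graph kept for backprop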
@@ -515,45 +606,6 @@ def forward(self, X: Tensor) -> Tensor:
         return (sigma * _ei_helper(u)).mean(dim=-1)


-def _get_noiseless_fantasy_model(
-    model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
-) -> FixedNoiseGP:
-    r"""Construct a fantasy model from a fitted model and provided fantasies.
-
-    The fantasy model uses the hyperparameters from the original fitted model and
-    assumes the fantasies are noiseless.
-
-    Args:
-        model: a fitted FixedNoiseGP
-        batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of
-            fantasies.
-        Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of
-            fantasies.
-
-    Returns:
-        The fantasy model.
-    """
-    # initialize a copy of FixedNoiseGP on the original training inputs
-    # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
-    # are used across all batches (by default, a GP with batched training data
-    # uses independent hyperparameters for each batch).
-    fantasy_model = FixedNoiseGP(
-        train_X=model.train_inputs[0],
-        train_Y=model.train_targets.unsqueeze(-1),
-        train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
-    )
-    # update training inputs/targets to be batch mode fantasies
-    fantasy_model.set_train_data(
-        inputs=batch_X_observed, targets=Y_fantasized, strict=False
-    )
-    # use noiseless fantasies
-    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
-    # load hyperparameters from original model
-    state_dict = deepcopy(model.state_dict())
-    fantasy_model.load_state_dict(state_dict)
-    return fantasy_model
-
-
 class UpperConfidenceBound(AnalyticAcquisitionFunction):
     r"""Single-outcome Upper Confidence Bound (UCB).
@@ -807,3 +859,42 @@ def _construct_dist(means: Tensor, sigmas: Tensor, inds: Tensor) -> Normal:
     mean = means.index_select(dim=-1, index=inds)
     sigma = sigmas.index_select(dim=-1, index=inds)
     return Normal(loc=mean, scale=sigma)
+
+
+def _get_noiseless_fantasy_model(
+    model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
+) -> FixedNoiseGP:
+    r"""Construct a fantasy model from a fitted model and provided fantasies.
+
+    The fantasy model uses the hyperparameters from the original fitted model and
+    assumes the fantasies are noiseless.
+
+    Args:
+        model: a fitted FixedNoiseGP
+        batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of
+            fantasies.
+        Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of
+            fantasies.
+
+    Returns:
+        The fantasy model.
+    """
+    # initialize a copy of FixedNoiseGP on the original training inputs
+    # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
+    # are used across all batches (by default, a GP with batched training data
+    # uses independent hyperparameters for each batch).
+    fantasy_model = FixedNoiseGP(
+        train_X=model.train_inputs[0],
+        train_Y=model.train_targets.unsqueeze(-1),
+        train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
+    )
+    # update training inputs/targets to be batch mode fantasies
+    fantasy_model.set_train_data(
+        inputs=batch_X_observed, targets=Y_fantasized, strict=False
+    )
+    # use noiseless fantasies
+    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
+    # load hyperparameters from original model
+    state_dict = deepcopy(model.state_dict())
+    fantasy_model.load_state_dict(state_dict)
+    return fantasy_model
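`_get_noiseless_fantasy_model` is unchanged by this commit; it has only been moved to the end of the module. Its trick is to construct the `FixedNoiseGP` on non-batched data first, so that it carries a single shared set of hyperparameters, and only then to swap in the batched fantasy data via `set_train_data`. A rough sketch of the resulting shape behavior on toy data (this mimics the private helper rather than importing it; the printed shapes are what I would expect under these assumptions):

import torch
from botorch.models import FixedNoiseGP

n, d, num_fantasies = 8, 2, 5
train_X = torch.rand(n, d, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 1e-4)

# non-batched construction: one shared set of hyperparameters
fantasy_model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)

# swap in batched fantasy data without re-initializing hyperparameters
batch_X = train_X.expand(num_fantasies, n, d)
Y_fantasized = torch.randn(num_fantasies, n, dtype=torch.double)
fantasy_model.set_train_data(inputs=batch_X, targets=Y_fantasized, strict=False)
# treat the fantasies as (essentially) noiseless, as in the helper above
fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)

# evaluating with an extra broadcast dimension (as LogNEI's forward does via
# X.unsqueeze(-3)) yields one posterior per fantasy per candidate
X_test = torch.rand(3, 1, d, dtype=torch.double)  # 3 candidates, q=1
posterior = fantasy_model.posterior(X_test.unsqueeze(-3))
print(posterior.mean.shape)  # expected: torch.Size([3, 5, 1, 1])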

botorch/acquisition/input_constructors.py

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@
     ConstrainedExpectedImprovement,
     ExpectedImprovement,
     LogExpectedImprovement,
+    LogNoisyExpectedImprovement,
     NoisyExpectedImprovement,
     PosteriorMean,
     ProbabilityOfImprovement,
@@ -380,7 +381,7 @@ def construct_inputs_constrained_ei(
     raise NotImplementedError  # pragma: nocover


-@acqf_input_constructor(NoisyExpectedImprovement)
+@acqf_input_constructor(NoisyExpectedImprovement, LogNoisyExpectedImprovement)
 def construct_inputs_noisy_ei(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
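Registering `LogNoisyExpectedImprovement` with the existing constructor means downstream helpers that assemble acquisition-function arguments from a dataset resolve it exactly like `NoisyExpectedImprovement`. A rough sketch of the lookup, assuming the module's `get_acqf_input_constructor` helper; the model and dataset are not constructed here:

from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.acquisition.input_constructors import get_acqf_input_constructor

# the decorator above registers LogNEI under the same constructor as NEI,
# so the registry lookup returns construct_inputs_noisy_ei for both classes
constructor = get_acqf_input_constructor(LogNoisyExpectedImprovement)

# given a fitted `model` and a SupervisedDataset `dataset` (not shown here):
# kwargs = constructor(model=model, training_data=dataset, num_fantasies=20)
# acqf = LogNoisyExpectedImprovement(**kwargs)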

test/acquisition/test_analytic.py

Lines changed: 49 additions & 11 deletions
@@ -14,6 +14,7 @@
     ConstrainedExpectedImprovement,
     ExpectedImprovement,
     LogExpectedImprovement,
+    LogNoisyExpectedImprovement,
     NoisyExpectedImprovement,
     PosteriorMean,
     ProbabilityOfImprovement,
@@ -551,14 +552,33 @@ def test_noisy_expected_improvement(self):
         for dtype in (torch.float, torch.double):
             model = self._get_model(dtype=dtype)
             X_observed = model.train_inputs[0]
-            nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=5)
+            nfan = 5
+            nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
+            LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
+            # before assigning, check that the attributes exist
+            self.assertTrue(hasattr(LogNEI, "model"))
+            self.assertTrue(hasattr(LogNEI, "best_f"))
+            self.assertTrue(isinstance(LogNEI.model, FixedNoiseGP))
+            LogNEI.model = nEI.model  # let the two share their values and fantasies
+            LogNEI.best_f = nEI.best_f
+
             X_test = torch.tensor(
                 [[[0.25]], [[0.75]]],
                 device=X_observed.device,
                 dtype=dtype,
-                requires_grad=True,
             )
+            X_test_log = X_test.clone()
+            X_test.requires_grad = True
+            X_test_log.requires_grad = True
             val = nEI(X_test)
+            # testing logNEI yields the same result (also checks dtype)
+            log_val = LogNEI(X_test_log)
+            exp_log_val = log_val.exp()
+            # notably, val[1] is usually zero in this test, which is precisely what
+            # gives rise to problems during optimization, and what logNEI avoids
+            # since it generally takes a large negative number (<-2000) and has
+            # strong gradient signals in this regime.
+            self.assertTrue(torch.allclose(exp_log_val, val))
             # test basics
             self.assertEqual(val.dtype, dtype)
             self.assertEqual(val.device.type, X_observed.device.type)
@@ -569,17 +589,35 @@ def test_noisy_expected_improvement(self):
             # test gradient
             val.sum().backward()
             self.assertGreater(X_test.grad[0].abs().item(), 1e-5)
-            # test without gradient
-            with torch.no_grad():
-                nEI(X_test)
+            # testing gradient through exp of log computation
+            exp_log_val.sum().backward()
+            # testing that the first gradient element coincides. The second is in the
+            # regime where the naive implementation loses accuracy.
+            atol = 1e-5 if dtype == torch.float32 else 1e-14
+            self.assertTrue(
+                torch.allclose(X_test.grad[0], X_test_log.grad[0], atol=atol)
+            )
+
             # test non-FixedNoiseGP model
             other_model = SingleTaskGP(X_observed, model.train_targets.unsqueeze(-1))
-            with self.assertRaises(UnsupportedError):
-                NoisyExpectedImprovement(other_model, X_observed, num_fantasies=5)
-            # Test with minimize
-            nEI = NoisyExpectedImprovement(
-                model, X_observed, num_fantasies=5, maximize=False
-            )
+            for constructor in (NoisyExpectedImprovement, LogNoisyExpectedImprovement):
+                with self.assertRaises(UnsupportedError):
+                    constructor(other_model, X_observed, num_fantasies=5)
+                # Test constructor with minimize
+                acqf = constructor(model, X_observed, num_fantasies=5, maximize=False)
+                # test evaluation without gradients enabled
+                with torch.no_grad():
+                    acqf(X_test)
+
+                # testing gradients are only propagated if X_observed requires them
+                # i.e. kernel hyper-parameters are not tracked through to best_f
+                X_observed.requires_grad = False
+                acqf = constructor(model, X_observed, num_fantasies=5)
+                self.assertFalse(acqf.best_f.requires_grad)
+
+                X_observed.requires_grad = True
+                acqf = constructor(model, X_observed, num_fantasies=5)
+                self.assertTrue(acqf.best_f.requires_grad)


 class TestScalarizedPosteriorMean(BotorchTestCase):
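The regime described in the test comments is easy to reproduce directly with the module's helpers: once the EI term underflows to exactly zero, its gradient is zero as well, whereas the log-space term stays finite with a strong gradient. A standalone illustration using the private `_ei_helper`/`_log_ei_helper` functions from botorch/acquisition/analytic.py (printed values are approximate):

import torch
from botorch.acquisition.analytic import _ei_helper, _log_ei_helper

# a stand-in for the scaled improvement u at an uninteresting candidate point
u = torch.tensor(-40.0, dtype=torch.double, requires_grad=True)

# naive EI-style term: underflows to exactly 0, so its gradient is exactly 0
# as well -- no signal left for a gradient-based optimizer
ei = _ei_helper(u)
ei.backward()
print(ei.item(), u.grad.item())  # 0.0 0.0

# log-space term used by LogNEI: large negative but finite, with a usable gradient
u2 = torch.tensor(-40.0, dtype=torch.double, requires_grad=True)
log_ei = _log_ei_helper(u2)
log_ei.backward()
print(log_ei.item(), u2.grad.item())  # roughly -808 and ~40, respectively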
