Skip to content

Commit 76062a6

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Extending t_batch_mode_transform usage for improvement-based acquisition functions (#1575)
Summary: Pull Request resolved: #1575 The shape checking of `t_batch_mode_transform` was previously turned off for `ExpectedImprovement` because `NoisyExpectedImprovement` broke it. This diff adds support for shape checking for EI, NEI, and logEI. Reviewed By: Balandat Differential Revision: D42103478 fbshipit-source-id: ddd4615d36f489fea1e54bee6ede2c7aef484d96
1 parent 64f53b5 commit 76062a6

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

botorch/acquisition/analytic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def __init__(
210210
self.register_buffer("best_f", torch.as_tensor(best_f))
211211
self.maximize = maximize
212212

213-
@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
213+
@t_batch_mode_transform(expected_q=1)
214214
def forward(self, X: Tensor) -> Tensor:
215215
r"""Evaluate Expected Improvement on the candidate set X.
216216
@@ -270,7 +270,7 @@ def __init__(
270270
self.register_buffer("best_f", torch.as_tensor(best_f))
271271
self.maximize = maximize
272272

273-
@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
273+
@t_batch_mode_transform(expected_q=1)
274274
def forward(self, X: Tensor) -> Tensor:
275275
r"""Evaluate logarithm of Expected Improvement on the candidate set X.
276276
@@ -498,6 +498,7 @@ def __init__(
498498
best_f, _ = Y_fantasized.max(dim=-1) if maximize else Y_fantasized.min(dim=-1)
499499
super().__init__(model=fantasy_model, best_f=best_f, maximize=maximize)
500500

501+
@t_batch_mode_transform(expected_q=1)
501502
def forward(self, X: Tensor) -> Tensor:
502503
r"""Evaluate Expected Improvement on the candidate set X.
503504
@@ -509,7 +510,9 @@ def forward(self, X: Tensor) -> Tensor:
509510
the given design points `X`.
510511
"""
511512
# add batch dimension for broadcasting to fantasy models
512-
return super().forward(X.unsqueeze(-3)).mean(dim=-1)
513+
mean, sigma = self._mean_and_sigma(X.unsqueeze(-3))
514+
u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
515+
return (sigma * _ei_helper(u)).mean(dim=-1)
513516

514517

515518
def _get_noiseless_fantasy_model(

0 commit comments

Comments
 (0)