
Commit 853a4c4

sdaulton authored and facebook-github-bot committed
Multi-Output Acquisition Functions (#2935)
Summary:
Pull Request resolved: #2935

This implements an abstract MultiOutputAcquisitionFunction in BoTorch, as well as two subclasses:
- MultiOutputPosteriorMean
- MultiOutputAcquisitionFunctionWrapper, a wrapper around multiple single-output acquisition functions

Reviewed By: bletham

Differential Revision: D77666057

fbshipit-source-id: 5eebe75885c5256b6bcf1c9ac67499c16015e70c
1 parent e7fef3b commit 853a4c4

File tree: 4 files changed (+225, -0 lines)

botorch/acquisition/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,7 @@
     SampleReducingMCAcquisitionFunction,
 )
 from botorch.acquisition.multi_step_lookahead import qMultiStepLookahead
+from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
     GenericMCObjective,
@@ -136,4 +137,5 @@
     "ScalarizedPosteriorTransform",
     "get_acquisition_function",
     "get_acqf_input_constructor",
+    "MultiOutputAcquisitionFunction",
 ]
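
With this re-export, the base class can be imported directly from botorch.acquisition. As an illustration of the abstract contract (a `(b) x q x d` input mapped to a `(b) x m` output), here is a minimal, hypothetical subclass; PosteriorStd is invented for this sketch and is not part of the commit.

from botorch.acquisition import MultiOutputAcquisitionFunction
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor


class PosteriorStd(MultiOutputAcquisitionFunction):
    r"""Per-output posterior standard deviation (a pure-exploration score)."""

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: Tensor) -> Tensor:
        # X has shape (b) x 1 x d after the transform; return a (b) x m tensor
        # with one acquisition value per model output.
        return self.model.posterior(X).variance.squeeze(-2).sqrt()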
botorch/acquisition/multioutput_acquisition.py (new file)

Lines changed: 126 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Abstract base module for multi-output acquisition functions."""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.utils.transforms import (
    average_over_ensemble_models,
    t_batch_mode_transform,
)
from torch import Tensor


class MultiOutputAcquisitionFunction(AcquisitionFunction, ABC):
    r"""Abstract base class for multi-output acquisition functions.

    These are intended to be optimized with a multi-objective optimizer (e.g.
    NSGA-II).
    """

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set X.

        Args:
            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
                design points each.

        Returns:
            A `(b) x m`-dim Tensor of acquisition function values at the given
            design points `X`.
        """

    def set_X_pending(self, X_pending: Tensor | None) -> None:
        r"""Set the pending points.

        Args:
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
                that have been submitted for function evaluation but have not
                yet been evaluated.
        """
        raise UnsupportedError(
            "X_pending is not supported for multi-output acquisition functions."
        )


class MultiOutputPosteriorMean(MultiOutputAcquisitionFunction):
    def __init__(self, model: Model, weights: Tensor | None = None) -> None:
        r"""Constructor for MultiOutputPosteriorMean.

        Maximization of all outputs is assumed by default. Minimizing an output
        can be achieved by setting its corresponding weight to a negative value.

        Args:
            model: A multi-output model.
            weights: A one-dimensional tensor with `m` elements representing the
                weights on the outputs.
        """
        super().__init__(model=model)
        if self.model.num_outputs < 2:
            raise NotImplementedError(
                "MultiPosteriorMean only supports multi-output models."
            )
        # TODO: this could be done via a posterior transform
        if weights is not None and weights.shape[0] != self.model.num_outputs:
            raise ValueError(
                f"weights must have {self.model.num_outputs} elements, but got"
                f" {weights.shape[0]}."
            )
        self.register_buffer("weights", weights)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set X.

        Args:
            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
                design points each.

        Returns:
            A `(b) x m`-dim Tensor of acquisition function values at the given
            design points `X`.
        """
        mean = self.model.posterior(X).mean.squeeze(-2)
        if self.weights is not None:
            return mean * self.weights
        return mean


class MultiOutputAcquisitionFunctionWrapper(MultiOutputAcquisitionFunction):
    r"""Multi-output wrapper around single-output acquisition functions."""

    def __init__(self, acqfs: list[AcquisitionFunction]) -> None:
        r"""Constructor for MultiOutputAcquisitionFunctionWrapper.

        Args:
            acqfs: A list of `m` acquisition functions.
        """
        # We could set the model to be an ensemble model consisting of the
        # models used in each acqf.
        super().__init__(model=acqfs[0].model)
        self.acqfs: list[AcquisitionFunction] = acqfs

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set X.

        Args:
            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
                design points each.

        Returns:
            A `(b) x m`-dim Tensor of acquisition function values at the given
            design points `X`.
        """
        return torch.stack([acqf(X) for acqf in self.acqfs], dim=-1)
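
For context, a usage sketch (not part of the commit): it assumes the standard BoTorch modeling API and made-up training data; the shapes and weights are illustrative only.

import torch
from botorch.acquisition.analytic import LogExpectedImprovement, UpperConfidenceBound
from botorch.acquisition.multioutput_acquisition import (
    MultiOutputAcquisitionFunctionWrapper,
    MultiOutputPosteriorMean,
)
from botorch.models import SingleTaskGP

# Toy data: two objectives observed at the same 3-dim inputs.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.stack([train_X.sum(dim=-1), (train_X**2).sum(dim=-1)], dim=-1)

# A two-output model: MultiOutputPosteriorMean returns one column per output,
# with the second output negated via the weights (i.e., minimized).
model = SingleTaskGP(train_X, train_Y)
acqf = MultiOutputPosteriorMean(
    model=model, weights=torch.tensor([1.0, -1.0], dtype=torch.double)
)
X = torch.rand(5, 1, 3, dtype=torch.double)  # b=5 candidate points, q=1
print(acqf(X).shape)  # torch.Size([5, 2])

# Alternatively, wrap independent single-output acquisition functions.
gp0 = SingleTaskGP(train_X, train_Y[:, :1])
gp1 = SingleTaskGP(train_X, train_Y[:, 1:])
multi_acqf = MultiOutputAcquisitionFunctionWrapper(
    acqfs=[
        LogExpectedImprovement(model=gp0, best_f=train_Y[:, 0].max()),
        UpperConfidenceBound(model=gp1, beta=2.0),
    ]
)
print(multi_acqf(X).shape)  # torch.Size([5, 2])

Either object produces an m-dimensional acquisition value per candidate, which is what a multi-objective optimizer consumes (see the NSGA-II sketch after the tests below).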

sphinx/source/acquisition.rst

Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,11 @@ Monte-Carlo Acquisition Function API
 .. autoclass:: MCAcquisitionFunction
     :members:

+Multi-Output Acquisition Function API
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.acquisition.multioutput_acquisition
+    :members:
+
 Base Classes for Multi-Objective Acquisition Function API
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.acquisition.multi_objective.base
New test file for botorch.acquisition.multioutput_acquisition

Lines changed: 92 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.acquisition.analytic import LogExpectedImprovement, UpperConfidenceBound
from botorch.acquisition.multioutput_acquisition import (
    MultiOutputAcquisitionFunction,
    MultiOutputAcquisitionFunctionWrapper,
    MultiOutputPosteriorMean,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


class DummyMultiOutputAcqf(MultiOutputAcquisitionFunction):
    def forward(self, X):
        pass


class TestMultiOutputAcquisitionFunction(BotorchTestCase):
    def test_abstract_raises(self):
        with self.assertRaises(TypeError):
            MultiOutputAcquisitionFunction()

    def test_set_X_pending(self) -> None:
        with self.assertRaisesRegex(
            UnsupportedError,
            "X_pending is not supported for multi-output acquisition functions.",
        ):
            DummyMultiOutputAcqf(
                model=MockModel(posterior=MockPosterior())
            ).set_X_pending(torch.ones(1, 1))

    def test_multioutput_posterior_mean(self) -> None:
        # test single output model
        with self.assertRaisesRegex(
            NotImplementedError, "MultiPosteriorMean only supports multi-output models."
        ):
            MultiOutputPosteriorMean(
                model=MockModel(posterior=MockPosterior(mean=torch.tensor([[1.0]])))
            )
        # test invalid weights
        with self.assertRaisesRegex(
            ValueError, "weights must have 2 elements, but got 1."
        ):
            MultiOutputPosteriorMean(
                model=MockModel(
                    posterior=MockPosterior(mean=torch.tensor([[1.0, 2.0]]))
                ),
                weights=torch.tensor([1.0]),
            )
        for dtype in (torch.float, torch.double):
            # basic test
            mean = torch.tensor([[1.0, 2.0]], dtype=dtype, device=self.device)
            acqf = MultiOutputPosteriorMean(
                model=MockModel(posterior=MockPosterior(mean=mean))
            )
            self.assertTrue(
                torch.equal(
                    acqf(torch.ones(1, 1, 1, dtype=dtype, device=self.device)),
                    mean.squeeze(-2),
                )
            )
            # test weights
            weights = torch.tensor([-1.0, 1.0], dtype=dtype, device=self.device)
            acqf = MultiOutputPosteriorMean(
                model=MockModel(posterior=MockPosterior(mean=mean)), weights=weights
            )
            self.assertTrue(
                torch.equal(
                    acqf(torch.ones(1, 1, 1, dtype=dtype, device=self.device)),
                    mean.squeeze(-2) * weights,
                )
            )

    def test_multioutput_wrapper(self) -> None:
        for dtype in (torch.float, torch.double):
            model = MockModel(
                posterior=MockPosterior(
                    mean=torch.tensor([[1.0]], dtype=dtype, device=self.device),
                    variance=torch.tensor([[0.1]], dtype=dtype, device=self.device),
                )
            )
            ei = LogExpectedImprovement(model=model, best_f=0.0)
            ucb = UpperConfidenceBound(model=model, beta=2.0)
            acqf = MultiOutputAcquisitionFunctionWrapper(acqfs=[ei, ucb])
            X = torch.ones(1, 1, 1, dtype=dtype, device=self.device)
            expected_af_vals = torch.stack([ei(X=X), ucb(X=X)], dim=-1)
            self.assertTrue(torch.equal(acqf(X), expected_af_vals))
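
The base class states that these acquisition functions are intended to be optimized with a multi-objective optimizer such as NSGA-II; the commit itself does not ship such an optimization loop. The sketch below is hypothetical and uses the third-party pymoo library, with made-up bounds, population size, and toy training data. Acquisition values are negated because pymoo minimizes.

import numpy as np
import torch
from botorch.acquisition.multioutput_acquisition import MultiOutputPosteriorMean
from botorch.models import SingleTaskGP
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import Problem
from pymoo.optimize import minimize

# Toy two-output model and acquisition function (made-up data).
train_X = torch.rand(16, 3, dtype=torch.double)
train_Y = torch.stack([train_X.sum(dim=-1), -(train_X**2).sum(dim=-1)], dim=-1)
acqf = MultiOutputPosteriorMean(model=SingleTaskGP(train_X, train_Y))


class AcqfProblem(Problem):
    """Expose a MultiOutputAcquisitionFunction to pymoo as a vectorized problem."""

    def __init__(self, acqf, dim, n_obj):
        # Hypothetical [0, 1]^d bounds; use the real search-space bounds in practice.
        super().__init__(n_var=dim, n_obj=n_obj, xl=0.0, xu=1.0)
        self.acqf = acqf

    def _evaluate(self, x, out, *args, **kwargs):
        X = torch.from_numpy(x).to(torch.double).unsqueeze(-2)  # (pop_size, 1, d)
        with torch.no_grad():
            out["F"] = -self.acqf(X).numpy()  # negate: acqf values are maximized


res = minimize(AcqfProblem(acqf, dim=3, n_obj=2), NSGA2(pop_size=32), ("n_gen", 20))
pareto_candidates = np.atleast_2d(res.X)  # designs on the acquisition Pareto front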
