
Commit b2562fe

saitcakmak authored and facebook-github-bot committed
Fix ModelListGP.condition_on_observations/fantasize bug (#1250)
Summary:
Pull Request resolved: #1250

`ModelListGP.fantasize` would fail because `X` was transformed into a per-model list of `X`s in `fantasize` and then wrapped into a list of `X`s again in `condition_on_observations`. This removes the duplication of `X` in `condition_on_observations`, so the correctly transformed `X`s are used for fantasizing.

Fixes #1247

Reviewed By: Balandat

Differential Revision: D36949105

fbshipit-source-id: 3a8a3deb4b78416b8a67cd70863298b009f04209
1 parent: 2d011fa · commit: b2562fe

File tree: 3 files changed (+43, −17 lines)
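
The failure is easiest to see from the call path: `fantasize` applies the input transforms and hands a per-model list of inputs to `condition_on_observations`, which previously wrapped that argument in a list again. Below is a minimal sketch of the now-working call; it is not part of the commit and simply mirrors the `test_fantasize` case added further down (two `SingleTaskGP` models wrapped in a two-output `ModelListGP`):

import torch
from botorch.models import ModelListGP, SingleTaskGP
from botorch.sampling.samplers import IIDNormalSampler

# Two single-output GPs wrapped into a two-output ModelListGP.
m1 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
m2 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
model = ModelListGP(m1, m2)

# Before this fix, the call below failed because `X` was wrapped into a
# per-model list in `fantasize` and then into a list again inside
# `condition_on_observations`.
fm = model.fantasize(torch.rand(3, 2), sampler=IIDNormalSampler(2))
# Each fantasy model now holds 2 fantasy batches of 5 original + 3 fantasized points.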

botorch/models/model_list_gp_regression.py

Lines changed: 13 additions & 11 deletions
@@ -49,15 +49,15 @@ def __init__(self, *gp_models: GPyTorchModel) -> None:
         super().__init__(*gp_models)
 
     def condition_on_observations(
-        self, X: Tensor, Y: Tensor, **kwargs: Any
+        self, X: List[Tensor], Y: Tensor, **kwargs: Any
     ) -> ModelListGP:
         r"""Condition the model on new observations.
 
         Args:
-            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
-                the feature space, `n'` is the number of points per batch, and
-                `batch_shape` is the batch shape (must be compatible with the
-                batch shape of the model).
+            X: A `m`-list of `batch_shape x n' x d`-dim Tensors, where `d` is the
+                dimension of the feature space, `n'` is the number of points
+                per batch, and `batch_shape` is the batch shape (must be compatible
+                with the batch shape of the model).
             Y: A `batch_shape' x n' x m`-dim Tensor, where `m` is the number of
                 model outputs, `n'` is the number of points per batch, and
                 `batch_shape'` is the batch shape of the observations.
@@ -73,24 +73,26 @@ def condition_on_observations(
             `n_i + n'` training examples, where the `n'` training examples have
             been added and all test-time caches have been updated.
         """
-        self._validate_tensor_args(
-            X=X, Y=Y, Yvar=kwargs.get("noise", None), strict=False
-        )
-        inputs = [X] * self.num_outputs
         if Y.shape[-1] != self.num_outputs:
             raise BotorchTensorDimensionError(
                 "Incorrect number of outputs for observations. Received "
                 f"{Y.shape[-1]} observation outputs, but model has "
                 f"{self.num_outputs} outputs."
             )
         targets = [Y[..., i] for i in range(Y.shape[-1])]
+        # This should never trigger, posterior call would fail.
+        assert len(targets) == len(X)
         if "noise" in kwargs:
             noise = kwargs.pop("noise")
-            # Note: dimension checks were performed in _validate_tensor_args
+            if noise.shape != Y.shape[-noise.dim() :]:
+                raise BotorchTensorDimensionError(
+                    "The shape of observation noise does not agree with the outcomes. "
+                    f"Received {noise.shape} noise with {Y.shape} outcomes."
+                )
             kwargs_ = {**kwargs, "noise": [noise[..., i] for i in range(Y.shape[-1])]}
         else:
             kwargs_ = kwargs
-        return super().get_fantasy_model(inputs, targets, **kwargs_)
+        return super().get_fantasy_model(X, targets, **kwargs_)
 
     def subset_output(self, idcs: List[int]) -> ModelListGP:
         r"""Subset the model along the output dimension.

test/models/test_gpytorch.py

Lines changed: 7 additions & 0 deletions
@@ -215,6 +215,13 @@ def test_validate_tensor_args(self):
                 BotorchTensorDimensionWarning
             ):
                 GPyTorchModel._validate_tensor_args(X, Y[0], strict=False)
+            # with Yvar
+            if len(output_dim_shape) > 0:
+                Yvar = torch.empty(torch.Size([n]) + output_dim_shape, **tkwargs)
+                GPyTorchModel._validate_tensor_args(X, Y, Yvar)
+                Yvar = torch.empty(n, 5, **tkwargs)
+                with self.assertRaises(BotorchTensorDimensionError):
+                    GPyTorchModel._validate_tensor_args(X, Y, Yvar)
 
     def test_fantasize_flag(self):
         train_X = torch.rand(5, 1)

test/models/test_model_list_gp_regression.py

Lines changed: 23 additions & 6 deletions
@@ -17,6 +17,7 @@
 from botorch.models.transforms import Standardize
 from botorch.models.transforms.input import Normalize
 from botorch.posteriors import GPyTorchPosterior
+from botorch.sampling.samplers import IIDNormalSampler
 from botorch.utils.testing import _get_random_data, BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -142,19 +143,19 @@ def test_ModelListGP(self):
         self.assertIsInstance(posterior.mvn, MultivariateNormal)
 
         # test condition_on_observations
-        f_x = torch.rand(2, 1, **tkwargs)
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(2, 2, **tkwargs)
         cm = model.condition_on_observations(f_x, f_y)
         self.assertIsInstance(cm, ModelListGP)
 
         # test condition_on_observations batched
-        f_x = torch.rand(3, 2, 1, **tkwargs)
+        f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(3, 2, 2, **tkwargs)
         cm = model.condition_on_observations(f_x, f_y)
         self.assertIsInstance(cm, ModelListGP)
 
         # test condition_on_observations batched (fast fantasies)
-        f_x = torch.rand(2, 1, **tkwargs)
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(3, 2, 2, **tkwargs)
         cm = model.condition_on_observations(f_x, f_y)
         self.assertIsInstance(cm, ModelListGP)
@@ -163,6 +164,10 @@ def test_ModelListGP(self):
         with self.assertRaises(BotorchTensorDimensionError):
             model.condition_on_observations(f_x, torch.rand(3, 2, 3, **tkwargs))
 
+        # test X having wrong size
+        with self.assertRaises(AssertionError):
+            cm = model.condition_on_observations(f_x[:1], f_y)
+
         # test posterior transform
         X = torch.rand(3, 1, **tkwargs)
         weights = torch.tensor([1, 2], **tkwargs)
@@ -222,21 +227,21 @@ def test_ModelListGP_fixed_noise(self):
         self.assertIsInstance(posterior.mvn, MultivariateNormal)
 
         # test condition_on_observations
-        f_x = torch.rand(2, 1, **tkwargs)
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(2, 2, **tkwargs)
         noise = 0.1 + 0.1 * torch.rand_like(f_y)
         cm = model.condition_on_observations(f_x, f_y, noise=noise)
         self.assertIsInstance(cm, ModelListGP)
 
         # test condition_on_observations batched
-        f_x = torch.rand(3, 2, 1, **tkwargs)
+        f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(3, 2, 2, **tkwargs)
         noise = 0.1 + 0.1 * torch.rand_like(f_y)
         cm = model.condition_on_observations(f_x, f_y, noise=noise)
         self.assertIsInstance(cm, ModelListGP)
 
         # test condition_on_observations batched (fast fantasies)
-        f_x = torch.rand(2, 1, **tkwargs)
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
         f_y = torch.rand(3, 2, 2, **tkwargs)
         noise = 0.1 + 0.1 * torch.rand(2, 2, **tkwargs)
         cm = model.condition_on_observations(f_x, f_y, noise=noise)
@@ -295,3 +300,15 @@ def test_transform_revert_train_inputs(self):
             )
             self.assertTrue(m._has_transformed_inputs)
             self.assertTrue(torch.equal(m._original_train_inputs, org_inputs[i]))
+
+    def test_fantasize(self):
+        m1 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
+        m2 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
+        modellist = ModelListGP(m1, m2)
+        fm = modellist.fantasize(torch.rand(3, 2), sampler=IIDNormalSampler(2))
+        self.assertIsInstance(fm, ModelListGP)
+        for i in range(2):
+            fm_i = fm.models[i]
+            self.assertIsInstance(fm_i, SingleTaskGP)
+            self.assertEqual(fm_i.train_inputs[0].shape, torch.Size([2, 8, 2]))
+            self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8]))
