Skip to content

Commit a5b075b

Browse files
ItsMrLin authored and facebook-github-bot committed
Add outcome_transform to FixedNoiseMultiTaskGP (#1255)
Summary: Pull Request resolved: #1255 Add outcome_transform to FixedNoiseMultiTaskGP Reviewed By: qingfeng10 Differential Revision: D37020585 fbshipit-source-id: a2613c4bc34f8aa10018be799dd3ff6ec38dda08
1 parent 3805046 commit a5b075b

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

botorch/models/multitask.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ def __init__(
113113
on p.s.d. matrices. A common prior for this is the `LKJ` prior.
114114
input_transform: An input transform that is applied in the model's
115115
forward pass.
116+
outcome_transform: An outcome transform that is applied to the
117+
training data during instantiation and to the posterior during
118+
inference (that is, the `Posterior` obtained by calling
119+
`.posterior` on the model will be on the original scale).
116120
117121
Example:
118122
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -308,6 +312,7 @@ def __init__(
308312
output_tasks: Optional[List[int]] = None,
309313
rank: Optional[int] = None,
310314
input_transform: Optional[InputTransform] = None,
315+
outcome_transform: Optional[OutcomeTransform] = None,
311316
) -> None:
312317
r"""Multi-Task GP model using an ICM kernel and known observation noise.
313318
@@ -328,6 +333,10 @@ def __init__(
328333
full rank (i.e. number of tasks) kernel.
329334
input_transform: An input transform that is applied in the model's
330335
forward pass.
336+
outcome_transform: An outcome transform that is applied to the
337+
training data during instantiation and to the posterior during
338+
inference (that is, the `Posterior` obtained by calling
339+
`.posterior` on the model will be on the original scale).
331340
332341
Example:
333342
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -344,6 +353,10 @@ def __init__(
344353
X=train_X, input_transform=input_transform
345354
)
346355
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
356+
357+
if outcome_transform is not None:
358+
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
359+
347360
# We'll instatiate a MultiTaskGP and simply override the likelihood
348361
super().__init__(
349362
train_X=train_X,
@@ -354,7 +367,11 @@ def __init__(
354367
rank=rank,
355368
task_covar_prior=task_covar_prior,
356369
input_transform=input_transform,
370+
outcome_transform=None, # outcome_transform is applied already
357371
)
372+
373+
if outcome_transform is not None:
374+
self.outcome_transform = outcome_transform
358375
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
359376
self.to(train_X)
360377

test/models/test_multitask.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _gen_model_single_output(**tkwargs):
8383

8484

8585
def _gen_fixed_noise_model_and_data(
86-
task_feature: int = 0, input_transform=None, **tkwargs
86+
task_feature: int = 0, input_transform=None, outcome_transform=None, **tkwargs
8787
):
8888
datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(yvar=0.05, **tkwargs)
8989
model = FixedNoiseMultiTaskGP(
@@ -92,6 +92,7 @@ def _gen_fixed_noise_model_and_data(
9292
train_Yvar,
9393
task_feature=task_feature,
9494
input_transform=input_transform,
95+
outcome_transform=outcome_transform,
9596
)
9697
return model.to(**tkwargs), datasets, (train_X, train_Y, train_Yvar)
9798

@@ -345,17 +346,18 @@ def test_MultiTaskGP_given_covar_module(self):
345346
class TestFixedNoiseMultiTaskGP(BotorchTestCase):
346347
def test_FixedNoiseMultiTaskGP(self):
347348
bounds = torch.tensor([[-1.0, 0.0], [1.0, 1.0]])
348-
for dtype, use_intf in itertools.product(
349-
(torch.float, torch.double), (False, True)
349+
for dtype, use_intf, use_octf in itertools.product(
350+
(torch.float, torch.double), (False, True), (False, True)
350351
):
351352
tkwargs = {"device": self.device, "dtype": dtype}
353+
octf = Standardize(m=1) if use_octf else None
352354
intf = (
353355
Normalize(d=2, bounds=bounds.to(**tkwargs), transform_on_train=True)
354356
if use_intf
355357
else None
356358
)
357359
model, _, (train_X, _, _) = _gen_fixed_noise_model_and_data(
358-
input_transform=intf, **tkwargs
360+
input_transform=intf, outcome_transform=octf, **tkwargs
359361
)
360362
self.assertIsInstance(model, FixedNoiseMultiTaskGP)
361363
self.assertEqual(model.num_outputs, 2)
@@ -370,6 +372,8 @@ def test_FixedNoiseMultiTaskGP(self):
370372
self.assertEqual(
371373
model.task_covar_module.covar_factor.shape[-1], model._rank
372374
)
375+
if use_octf:
376+
self.assertIsInstance(model.outcome_transform, Standardize)
373377
if use_intf:
374378
self.assertIsInstance(model.input_transform, Normalize)
375379

@@ -394,6 +398,16 @@ def test_FixedNoiseMultiTaskGP(self):
394398
self.assertEqual(posterior_f.mean.shape, torch.Size([2, 2]))
395399
self.assertEqual(posterior_f.variance.shape, torch.Size([2, 2]))
396400

401+
# check posterior transform is applied
402+
if use_octf:
403+
posterior_pred = model.posterior(test_x)
404+
tmp_tf = model.outcome_transform
405+
del model.outcome_transform
406+
pp_tf = model.posterior(test_x)
407+
model.outcome_transform = tmp_tf
408+
expected_var = tmp_tf.untransform_posterior(pp_tf).variance
409+
self.assertTrue(torch.allclose(posterior_pred.variance, expected_var))
410+
397411
# test that posterior w/ observation noise raises appropriate error
398412
with self.assertRaises(NotImplementedError):
399413
model.posterior(test_x, observation_noise=True)

0 commit comments

Comments
 (0)