Skip to content

Commit 1ed307c

Browse files
qingfeng10facebook-github-bot
authored andcommitted
Make task_feature as required input in MultiTaskGP.construct_inputs (#1246)
Summary: Pull Request resolved: #1246 Motivation: We see an error in CV when using ListSurrogate (a list of MTGP) in Models.BOTORCH_MODULAR. See N2054031 for the full error trace. The error happens when loading state_dict. The reason is that a wrong task_feature is passed to MTGP (use default value = 0) and thus, the task_covar_module has a wrong shape. The proposed the fix here is: Make task_feature as required input in the MultiTaskGP.construct_inputs Reviewed By: Balandat Differential Revision: D36925653 fbshipit-source-id: 8cfde439083bb47d141f6cfc872646ffce7838de
1 parent b2562fe commit 1ed307c

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

botorch/models/multitask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def get_all_tasks(
239239
def construct_inputs(
240240
cls,
241241
training_data: Dict[str, SupervisedDataset],
242-
task_feature: int = 0,
242+
task_feature: int,
243243
task_covar_prior: Optional[Prior] = None,
244244
prior_config: Optional[dict] = None,
245245
rank: Optional[int] = None,

test/models/test_multitask.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ def _gen_datasets(yvar: Optional[float] = None, **tkwargs):
6262
return datasets, (train_X, train_Y, train_Yvar)
6363

6464

65-
def _gen_model_and_data(input_transform=None, outcome_transform=None, **tkwargs):
65+
def _gen_model_and_data(
66+
task_feature: int = 0, input_transform=None, outcome_transform=None, **tkwargs
67+
):
6668
datasets, (train_X, train_Y) = _gen_datasets(**tkwargs)
6769
model = MultiTaskGP(
6870
train_X,
6971
train_Y,
70-
task_feature=0,
72+
task_feature=task_feature,
7173
input_transform=input_transform,
7274
outcome_transform=outcome_transform,
7375
)
@@ -80,10 +82,16 @@ def _gen_model_single_output(**tkwargs):
8082
return model.to(**tkwargs)
8183

8284

83-
def _gen_fixed_noise_model_and_data(input_transform=None, **tkwargs):
85+
def _gen_fixed_noise_model_and_data(
86+
task_feature: int = 0, input_transform=None, **tkwargs
87+
):
8488
datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(yvar=0.05, **tkwargs)
8589
model = FixedNoiseMultiTaskGP(
86-
train_X, train_Y, train_Yvar, task_feature=0, input_transform=input_transform
90+
train_X,
91+
train_Y,
92+
train_Yvar,
93+
task_feature=task_feature,
94+
input_transform=input_transform,
8795
)
8896
return model.to(**tkwargs), datasets, (train_X, train_Y, train_Yvar)
8997

@@ -488,69 +496,84 @@ def test_FixedNoiseMultiTaskGP_given_covar_module(self):
488496
def test_MultiTaskGP_construct_inputs(self):
489497
for dtype in (torch.float, torch.double):
490498
tkwargs = {"device": self.device, "dtype": dtype}
491-
model, datasets, (train_X, train_Y) = _gen_model_and_data(**tkwargs)
499+
task_feature = 0
500+
model, datasets, (train_X, train_Y) = _gen_model_and_data(
501+
task_feature=task_feature, **tkwargs
502+
)
492503

493504
# Validate prior config.
494505
with self.assertRaisesRegex(
495506
ValueError, ".* only config for LKJ prior is supported"
496507
):
497508
data_dict = model.construct_inputs(
498509
datasets,
510+
task_feature=task_feature,
499511
prior_config={"use_LKJ_prior": False},
500512
)
501513
# Validate eta.
502514
with self.assertRaisesRegex(ValueError, "eta must be a real number"):
503515
data_dict = model.construct_inputs(
504516
datasets,
517+
task_feature=task_feature,
505518
prior_config={"use_LKJ_prior": True, "eta": "not_number"},
506519
)
507520
# Test that presence of `prior` and `prior_config` kwargs at the
508521
# same time causes error.
509522
with self.assertRaisesRegex(ValueError, "Only one of"):
510523
data_dict = model.construct_inputs(
511524
datasets,
525+
task_feature=task_feature,
512526
task_covar_prior=1,
513527
prior_config={"use_LKJ_prior": True, "eta": "not_number"},
514528
)
515529
data_dict = model.construct_inputs(
516530
datasets,
531+
task_feature=task_feature,
517532
prior_config={"use_LKJ_prior": True, "eta": 0.6},
518533
)
519534
self.assertTrue(torch.equal(data_dict["train_X"], train_X))
520535
self.assertTrue(torch.equal(data_dict["train_Y"], train_Y))
521-
self.assertEqual(data_dict["task_feature"], 0)
536+
self.assertEqual(data_dict["task_feature"], task_feature)
522537
self.assertIsInstance(data_dict["task_covar_prior"], LKJCovariancePrior)
523538

524539
def test_FixedNoiseMultiTaskGP_construct_inputs(self):
525540
for dtype in (torch.float, torch.double):
526541
tkwargs = {"device": self.device, "dtype": dtype}
542+
task_feature = 0
527543

528544
(
529545
model,
530546
datasets,
531547
(train_X, train_Y, train_Yvar),
532-
) = _gen_fixed_noise_model_and_data(**tkwargs)
548+
) = _gen_fixed_noise_model_and_data(task_feature=task_feature, **tkwargs)
533549

534550
# Test only one of `task_covar_prior` and `prior_config` can be passed.
535551
with self.assertRaisesRegex(ValueError, "Only one of"):
536-
model.construct_inputs(datasets, task_covar_prior=1, prior_config=1)
552+
model.construct_inputs(
553+
datasets,
554+
task_feature=task_feature,
555+
task_covar_prior=1,
556+
prior_config=1,
557+
)
537558

538559
# Validate prior config.
539560
with self.assertRaisesRegex(
540561
ValueError, ".* only config for LKJ prior is supported"
541562
):
542563
data_dict = model.construct_inputs(
543564
datasets,
565+
task_feature=task_feature,
544566
prior_config={"use_LKJ_prior": False},
545567
)
546568
data_dict = model.construct_inputs(
547569
datasets,
570+
task_feature=task_feature,
548571
prior_config={"use_LKJ_prior": True, "eta": 0.6},
549572
)
550573
self.assertTrue(torch.equal(data_dict["train_X"], train_X))
551574
self.assertTrue(torch.equal(data_dict["train_Y"], train_Y))
552575
self.assertTrue(torch.allclose(data_dict["train_Yvar"], train_Yvar))
553-
self.assertEqual(data_dict["task_feature"], 0)
576+
self.assertEqual(data_dict["task_feature"], task_feature)
554577
self.assertIsInstance(data_dict["task_covar_prior"], LKJCovariancePrior)
555578

556579

0 commit comments

Comments
 (0)