@@ -62,12 +62,14 @@ def _gen_datasets(yvar: Optional[float] = None, **tkwargs):
6262 return datasets , (train_X , train_Y , train_Yvar )
6363
6464
65- def _gen_model_and_data (input_transform = None , outcome_transform = None , ** tkwargs ):
65+ def _gen_model_and_data (
66+ task_feature : int = 0 , input_transform = None , outcome_transform = None , ** tkwargs
67+ ):
6668 datasets , (train_X , train_Y ) = _gen_datasets (** tkwargs )
6769 model = MultiTaskGP (
6870 train_X ,
6971 train_Y ,
70- task_feature = 0 ,
72+ task_feature = task_feature ,
7173 input_transform = input_transform ,
7274 outcome_transform = outcome_transform ,
7375 )
@@ -80,10 +82,16 @@ def _gen_model_single_output(**tkwargs):
8082 return model .to (** tkwargs )
8183
8284
83- def _gen_fixed_noise_model_and_data (input_transform = None , ** tkwargs ):
85+ def _gen_fixed_noise_model_and_data (
86+ task_feature : int = 0 , input_transform = None , ** tkwargs
87+ ):
8488 datasets , (train_X , train_Y , train_Yvar ) = _gen_datasets (yvar = 0.05 , ** tkwargs )
8589 model = FixedNoiseMultiTaskGP (
86- train_X , train_Y , train_Yvar , task_feature = 0 , input_transform = input_transform
90+ train_X ,
91+ train_Y ,
92+ train_Yvar ,
93+ task_feature = task_feature ,
94+ input_transform = input_transform ,
8795 )
8896 return model .to (** tkwargs ), datasets , (train_X , train_Y , train_Yvar )
8997
@@ -488,69 +496,84 @@ def test_FixedNoiseMultiTaskGP_given_covar_module(self):
488496 def test_MultiTaskGP_construct_inputs (self ):
489497 for dtype in (torch .float , torch .double ):
490498 tkwargs = {"device" : self .device , "dtype" : dtype }
491- model , datasets , (train_X , train_Y ) = _gen_model_and_data (** tkwargs )
499+ task_feature = 0
500+ model , datasets , (train_X , train_Y ) = _gen_model_and_data (
501+ task_feature = task_feature , ** tkwargs
502+ )
492503
493504 # Validate prior config.
494505 with self .assertRaisesRegex (
495506 ValueError , ".* only config for LKJ prior is supported"
496507 ):
497508 data_dict = model .construct_inputs (
498509 datasets ,
510+ task_feature = task_feature ,
499511 prior_config = {"use_LKJ_prior" : False },
500512 )
501513 # Validate eta.
502514 with self .assertRaisesRegex (ValueError , "eta must be a real number" ):
503515 data_dict = model .construct_inputs (
504516 datasets ,
517+ task_feature = task_feature ,
505518 prior_config = {"use_LKJ_prior" : True , "eta" : "not_number" },
506519 )
507520 # Test that presence of `prior` and `prior_config` kwargs at the
508521 # same time causes error.
509522 with self .assertRaisesRegex (ValueError , "Only one of" ):
510523 data_dict = model .construct_inputs (
511524 datasets ,
525+ task_feature = task_feature ,
512526 task_covar_prior = 1 ,
513527 prior_config = {"use_LKJ_prior" : True , "eta" : "not_number" },
514528 )
515529 data_dict = model .construct_inputs (
516530 datasets ,
531+ task_feature = task_feature ,
517532 prior_config = {"use_LKJ_prior" : True , "eta" : 0.6 },
518533 )
519534 self .assertTrue (torch .equal (data_dict ["train_X" ], train_X ))
520535 self .assertTrue (torch .equal (data_dict ["train_Y" ], train_Y ))
521- self .assertEqual (data_dict ["task_feature" ], 0 )
536+ self .assertEqual (data_dict ["task_feature" ], task_feature )
522537 self .assertIsInstance (data_dict ["task_covar_prior" ], LKJCovariancePrior )
523538
524539 def test_FixedNoiseMultiTaskGP_construct_inputs (self ):
525540 for dtype in (torch .float , torch .double ):
526541 tkwargs = {"device" : self .device , "dtype" : dtype }
542+ task_feature = 0
527543
528544 (
529545 model ,
530546 datasets ,
531547 (train_X , train_Y , train_Yvar ),
532- ) = _gen_fixed_noise_model_and_data (** tkwargs )
548+ ) = _gen_fixed_noise_model_and_data (task_feature = task_feature , ** tkwargs )
533549
534550 # Test only one of `task_covar_prior` and `prior_config` can be passed.
535551 with self .assertRaisesRegex (ValueError , "Only one of" ):
536- model .construct_inputs (datasets , task_covar_prior = 1 , prior_config = 1 )
552+ model .construct_inputs (
553+ datasets ,
554+ task_feature = task_feature ,
555+ task_covar_prior = 1 ,
556+ prior_config = 1 ,
557+ )
537558
538559 # Validate prior config.
539560 with self .assertRaisesRegex (
540561 ValueError , ".* only config for LKJ prior is supported"
541562 ):
542563 data_dict = model .construct_inputs (
543564 datasets ,
565+ task_feature = task_feature ,
544566 prior_config = {"use_LKJ_prior" : False },
545567 )
546568 data_dict = model .construct_inputs (
547569 datasets ,
570+ task_feature = task_feature ,
548571 prior_config = {"use_LKJ_prior" : True , "eta" : 0.6 },
549572 )
550573 self .assertTrue (torch .equal (data_dict ["train_X" ], train_X ))
551574 self .assertTrue (torch .equal (data_dict ["train_Y" ], train_Y ))
552575 self .assertTrue (torch .allclose (data_dict ["train_Yvar" ], train_Yvar ))
553- self .assertEqual (data_dict ["task_feature" ], 0 )
576+ self .assertEqual (data_dict ["task_feature" ], task_feature )
554577 self .assertIsInstance (data_dict ["task_covar_prior" ], LKJCovariancePrior )
555578
556579
0 commit comments