@@ -83,7 +83,7 @@ def _gen_model_single_output(**tkwargs):
8383
8484
8585def _gen_fixed_noise_model_and_data (
86- task_feature : int = 0 , input_transform = None , ** tkwargs
86+ task_feature : int = 0 , input_transform = None , outcome_transform = None , ** tkwargs
8787):
8888 datasets , (train_X , train_Y , train_Yvar ) = _gen_datasets (yvar = 0.05 , ** tkwargs )
8989 model = FixedNoiseMultiTaskGP (
@@ -92,6 +92,7 @@ def _gen_fixed_noise_model_and_data(
9292 train_Yvar ,
9393 task_feature = task_feature ,
9494 input_transform = input_transform ,
95+ outcome_transform = outcome_transform ,
9596 )
9697 return model .to (** tkwargs ), datasets , (train_X , train_Y , train_Yvar )
9798
@@ -345,17 +346,18 @@ def test_MultiTaskGP_given_covar_module(self):
345346class TestFixedNoiseMultiTaskGP (BotorchTestCase ):
346347 def test_FixedNoiseMultiTaskGP (self ):
347348 bounds = torch .tensor ([[- 1.0 , 0.0 ], [1.0 , 1.0 ]])
348- for dtype , use_intf in itertools .product (
349- (torch .float , torch .double ), (False , True )
349+ for dtype , use_intf , use_octf in itertools .product (
350+ (torch .float , torch .double ), (False , True ), ( False , True )
350351 ):
351352 tkwargs = {"device" : self .device , "dtype" : dtype }
353+ octf = Standardize (m = 1 ) if use_octf else None
352354 intf = (
353355 Normalize (d = 2 , bounds = bounds .to (** tkwargs ), transform_on_train = True )
354356 if use_intf
355357 else None
356358 )
357359 model , _ , (train_X , _ , _ ) = _gen_fixed_noise_model_and_data (
358- input_transform = intf , ** tkwargs
360+ input_transform = intf , outcome_transform = octf , ** tkwargs
359361 )
360362 self .assertIsInstance (model , FixedNoiseMultiTaskGP )
361363 self .assertEqual (model .num_outputs , 2 )
@@ -370,6 +372,8 @@ def test_FixedNoiseMultiTaskGP(self):
370372 self .assertEqual (
371373 model .task_covar_module .covar_factor .shape [- 1 ], model ._rank
372374 )
375+ if use_octf :
376+ self .assertIsInstance (model .outcome_transform , Standardize )
373377 if use_intf :
374378 self .assertIsInstance (model .input_transform , Normalize )
375379
@@ -394,6 +398,16 @@ def test_FixedNoiseMultiTaskGP(self):
394398 self .assertEqual (posterior_f .mean .shape , torch .Size ([2 , 2 ]))
395399 self .assertEqual (posterior_f .variance .shape , torch .Size ([2 , 2 ]))
396400
401+ # check posterior transform is applied
402+ if use_octf :
403+ posterior_pred = model .posterior (test_x )
404+ tmp_tf = model .outcome_transform
405+ del model .outcome_transform
406+ pp_tf = model .posterior (test_x )
407+ model .outcome_transform = tmp_tf
408+ expected_var = tmp_tf .untransform_posterior (pp_tf ).variance
409+ self .assertTrue (torch .allclose (posterior_pred .variance , expected_var ))
410+
397411 # test that posterior w/ observation noise raises appropriate error
398412 with self .assertRaises (NotImplementedError ):
399413 model .posterior (test_x , observation_noise = True )
0 commit comments