1919 batched_to_model_list ,
2020 model_list_to_batched ,
2121)
22- from botorch .models .transforms .input import Normalize
22+ from botorch .models .transforms .input import AppendFeatures , Normalize
2323from botorch .models .transforms .outcome import Standardize
2424from botorch .utils .testing import BotorchTestCase
2525from gpytorch .likelihoods import GaussianLikelihood
@@ -80,6 +80,16 @@ def test_batched_to_model_list(self):
8080 expected_octf .__getattr__ (attr_name ),
8181 )
8282 )
83+ # test with AppendFeatures
84+ input_tf = AppendFeatures (
85+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
86+ )
87+ batch_gp = SingleTaskGP (
88+ train_X , train_Y , outcome_transform = octf , input_transform = input_tf
89+ ).eval ()
90+ list_gp = batched_to_model_list (batch_gp )
91+ self .assertIsInstance (list_gp , ModelListGP )
92+ self .assertIsInstance (list_gp .models [0 ].input_transform , AppendFeatures )
8393
8494 def test_model_list_to_batched (self ):
8595 for dtype in (torch .float , torch .double ):
@@ -167,6 +177,16 @@ def test_model_list_to_batched(self):
167177 self .assertTrue (
168178 torch .equal (batch_gp .input_transform .bounds , input_tf .bounds )
169179 )
180+ # test with AppendFeatures
181+ input_tf3 = AppendFeatures (
182+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
183+ )
184+ gp1_ = SingleTaskGP (train_X , train_Y1 , input_transform = input_tf3 )
185+ gp2_ = SingleTaskGP (train_X , train_Y2 , input_transform = input_tf3 )
186+ list_gp = ModelListGP (gp1_ , gp2_ ).eval ()
187+ batch_gp = model_list_to_batched (list_gp )
188+ self .assertIsInstance (batch_gp , SingleTaskGP )
189+ self .assertIsInstance (batch_gp .input_transform , AppendFeatures )
170190 # test different input transforms
171191 input_tf2 = Normalize (
172192 d = 2 ,
@@ -177,7 +197,7 @@ def test_model_list_to_batched(self):
177197 gp1_ = SingleTaskGP (train_X , train_Y1 , input_transform = input_tf )
178198 gp2_ = SingleTaskGP (train_X , train_Y2 , input_transform = input_tf2 )
179199 list_gp = ModelListGP (gp1_ , gp2_ )
180- with self .assertRaises (UnsupportedError ):
200+ with self .assertRaisesRegex (UnsupportedError , "have the same" ):
181201 model_list_to_batched (list_gp )
182202
183203 # test batched input transform
@@ -292,17 +312,26 @@ def test_batched_multi_output_to_single_output(self):
292312 self .assertTrue (
293313 torch .equal (batch_so_model .input_transform .bounds , input_tf .bounds )
294314 )
315+ # test with AppendFeatures
316+ input_tf = AppendFeatures (
317+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
318+ )
319+ batched_mo_model = SingleTaskGP (
320+ train_X , train_Y , input_transform = input_tf
321+ ).eval ()
322+ batch_so_model = batched_multi_output_to_single_output (batched_mo_model )
323+ self .assertIsInstance (batch_so_model .input_transform , AppendFeatures )
295324
296325 # test batched input transform
297- input_tf2 = Normalize (
326+ input_tf = Normalize (
298327 d = 2 ,
299328 bounds = torch .tensor (
300329 [[- 1.0 , - 1.0 ], [1.0 , 1.0 ]], device = self .device , dtype = dtype
301330 ),
302331 batch_shape = torch .Size ([2 ]),
303332 )
304- batched_mo_model = SingleTaskGP (train_X , train_Y , input_transform = input_tf2 )
305- batched_so_model = batched_multi_output_to_single_output (batched_mo_model )
333+ batched_mo_model = SingleTaskGP (train_X , train_Y , input_transform = input_tf )
334+ batch_so_model = batched_multi_output_to_single_output (batched_mo_model )
306335 self .assertIsInstance (batch_so_model .input_transform , Normalize )
307336 self .assertTrue (
308337 torch .equal (batch_so_model .input_transform .bounds , input_tf .bounds )
0 commit comments