1717from botorch .models .transforms import Standardize
1818from botorch .models .transforms .input import Normalize
1919from botorch .posteriors import GPyTorchPosterior
20+ from botorch .sampling .samplers import IIDNormalSampler
2021from botorch .utils .testing import _get_random_data , BotorchTestCase
2122from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
2223from gpytorch .kernels import MaternKernel , ScaleKernel
@@ -142,19 +143,19 @@ def test_ModelListGP(self):
142143 self .assertIsInstance (posterior .mvn , MultivariateNormal )
143144
144145 # test condition_on_observations
145- f_x = torch .rand (2 , 1 , ** tkwargs )
146+ f_x = [ torch .rand (2 , 1 , ** tkwargs ) for _ in range ( 2 )]
146147 f_y = torch .rand (2 , 2 , ** tkwargs )
147148 cm = model .condition_on_observations (f_x , f_y )
148149 self .assertIsInstance (cm , ModelListGP )
149150
150151 # test condition_on_observations batched
151- f_x = torch .rand (3 , 2 , 1 , ** tkwargs )
152+ f_x = [ torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range ( 2 )]
152153 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
153154 cm = model .condition_on_observations (f_x , f_y )
154155 self .assertIsInstance (cm , ModelListGP )
155156
156157 # test condition_on_observations batched (fast fantasies)
157- f_x = torch .rand (2 , 1 , ** tkwargs )
158+ f_x = [ torch .rand (2 , 1 , ** tkwargs ) for _ in range ( 2 )]
158159 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
159160 cm = model .condition_on_observations (f_x , f_y )
160161 self .assertIsInstance (cm , ModelListGP )
@@ -163,6 +164,10 @@ def test_ModelListGP(self):
163164 with self .assertRaises (BotorchTensorDimensionError ):
164165 model .condition_on_observations (f_x , torch .rand (3 , 2 , 3 , ** tkwargs ))
165166
167+ # test X having wrong size
168+ with self .assertRaises (AssertionError ):
169+ cm = model .condition_on_observations (f_x [:1 ], f_y )
170+
166171 # test posterior transform
167172 X = torch .rand (3 , 1 , ** tkwargs )
168173 weights = torch .tensor ([1 , 2 ], ** tkwargs )
@@ -222,21 +227,21 @@ def test_ModelListGP_fixed_noise(self):
222227 self .assertIsInstance (posterior .mvn , MultivariateNormal )
223228
224229 # test condition_on_observations
225- f_x = torch .rand (2 , 1 , ** tkwargs )
230+ f_x = [ torch .rand (2 , 1 , ** tkwargs ) for _ in range ( 2 )]
226231 f_y = torch .rand (2 , 2 , ** tkwargs )
227232 noise = 0.1 + 0.1 * torch .rand_like (f_y )
228233 cm = model .condition_on_observations (f_x , f_y , noise = noise )
229234 self .assertIsInstance (cm , ModelListGP )
230235
231236 # test condition_on_observations batched
232- f_x = torch .rand (3 , 2 , 1 , ** tkwargs )
237+ f_x = [ torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range ( 2 )]
233238 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
234239 noise = 0.1 + 0.1 * torch .rand_like (f_y )
235240 cm = model .condition_on_observations (f_x , f_y , noise = noise )
236241 self .assertIsInstance (cm , ModelListGP )
237242
238243 # test condition_on_observations batched (fast fantasies)
239- f_x = torch .rand (2 , 1 , ** tkwargs )
244+ f_x = [ torch .rand (2 , 1 , ** tkwargs ) for _ in range ( 2 )]
240245 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
241246 noise = 0.1 + 0.1 * torch .rand (2 , 2 , ** tkwargs )
242247 cm = model .condition_on_observations (f_x , f_y , noise = noise )
@@ -295,3 +300,15 @@ def test_transform_revert_train_inputs(self):
295300 )
296301 self .assertTrue (m ._has_transformed_inputs )
297302 self .assertTrue (torch .equal (m ._original_train_inputs , org_inputs [i ]))
303+
304+ def test_fantasize (self ):
305+ m1 = SingleTaskGP (torch .rand (5 , 2 ), torch .rand (5 , 1 )).eval ()
306+ m2 = SingleTaskGP (torch .rand (5 , 2 ), torch .rand (5 , 1 )).eval ()
307+ modellist = ModelListGP (m1 , m2 )
308+ fm = modellist .fantasize (torch .rand (3 , 2 ), sampler = IIDNormalSampler (2 ))
309+ self .assertIsInstance (fm , ModelListGP )
310+ for i in range (2 ):
311+ fm_i = fm .models [i ]
312+ self .assertIsInstance (fm_i , SingleTaskGP )
313+ self .assertEqual (fm_i .train_inputs [0 ].shape , torch .Size ([2 , 8 , 2 ]))
314+ self .assertEqual (fm_i .train_targets .shape , torch .Size ([2 , 8 ]))
0 commit comments