1414from botorch .fit import fit_gpytorch_mll
1515from botorch .models import ModelListGP
1616from botorch .models .gp_regression import FixedNoiseGP , SingleTaskGP
17- from botorch .models .transforms import Standardize
1817from botorch .models .transforms .input import Normalize
19- from botorch .posteriors import GPyTorchPosterior
18+ from botorch .models .transforms .outcome import ChainedOutcomeTransform , Log , Standardize
19+ from botorch .posteriors import GPyTorchPosterior , PosteriorList , TransformedPosterior
2020from botorch .sampling .normal import IIDNormalSampler
2121from botorch .utils .testing import _get_random_data , BotorchTestCase
2222from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
2828from gpytorch .priors import GammaPrior
2929
3030
31- def _get_model (fixed_noise = False , use_octf = False , use_intf = False , ** tkwargs ):
31+ def _get_model (
32+ fixed_noise = False , outcome_transform : str = "None" , use_intf = False , ** tkwargs
33+ ) -> ModelListGP :
3234 train_x1 , train_y1 = _get_random_data (
3335 batch_shape = torch .Size (), m = 1 , n = 10 , ** tkwargs
3436 )
37+ train_y1 = torch .exp (train_y1 )
3538 train_x2 , train_y2 = _get_random_data (
3639 batch_shape = torch .Size (), m = 1 , n = 11 , ** tkwargs
3740 )
38- octfs = [Standardize (m = 1 ), Standardize (m = 1 )] if use_octf else [None , None ]
41+ if outcome_transform == "Standardize" :
42+ octfs = [Standardize (m = 1 ), Standardize (m = 1 )]
43+ elif outcome_transform == "Log" :
44+ octfs = [Log (), Standardize (m = 1 )]
45+ elif outcome_transform == "Chained" :
46+ octfs = [
47+ ChainedOutcomeTransform (
48+ chained = ChainedOutcomeTransform (log = Log (), standardize = Standardize (m = 1 ))
49+ ),
50+ Standardize (m = 1 ),
51+ ]
52+ elif outcome_transform == "None" :
53+ octfs = [None , None ]
54+ else :
55+ raise KeyError ( # pragma: no cover
56+ "outcome_transform must be one of 'Standardize', 'Log', 'Chained', or "
57+ "'None'."
58+ )
3959 intfs = [Normalize (d = 1 ), Normalize (d = 1 )] if use_intf else [None , None ]
4060 if fixed_noise :
4161 train_y1_var = 0.1 + 0.1 * torch .rand_like (train_y1 , ** tkwargs )
@@ -73,10 +93,12 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
7393
7494class TestModelListGP (BotorchTestCase ):
7595 def _base_test_ModelListGP (
76- self , fixed_noise : bool , dtype , use_octf : bool
96+ self , fixed_noise : bool , dtype , outcome_transform : str
7797 ) -> ModelListGP :
7898 tkwargs = {"device" : self .device , "dtype" : dtype }
79- model = _get_model (fixed_noise = fixed_noise , use_octf = use_octf , ** tkwargs )
99+ model = _get_model (
100+ fixed_noise = fixed_noise , outcome_transform = outcome_transform , ** tkwargs
101+ )
80102 self .assertIsInstance (model , ModelListGP )
81103 self .assertIsInstance (model .likelihood , LikelihoodList )
82104 for m in model .models :
@@ -85,8 +107,12 @@ def _base_test_ModelListGP(
85107 matern_kernel = m .covar_module .base_kernel
86108 self .assertIsInstance (matern_kernel , MaternKernel )
87109 self .assertIsInstance (matern_kernel .lengthscale_prior , GammaPrior )
88- if use_octf :
89- self .assertIsInstance (m .outcome_transform , Standardize )
110+ if outcome_transform != "None" :
111+ self .assertIsInstance (
112+ m .outcome_transform , (Log , Standardize , ChainedOutcomeTransform )
113+ )
114+ else :
115+ assert not hasattr (m , "outcome_transform" )
90116
91117 # test constructing likelihood wrapper
92118 mll = SumMarginalLogLikelihood (model .likelihood , model )
@@ -121,9 +147,19 @@ def _base_test_ModelListGP(
121147 # test posterior
122148 test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
123149 posterior = model .posterior (test_x )
124- self .assertIsInstance (posterior , GPyTorchPosterior )
125- self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
126- if use_octf :
150+ gpytorch_posterior_expected = outcome_transform in ("None" , "Standardize" )
151+ expected_type = (
152+ GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
153+ )
154+ self .assertIsInstance (posterior , expected_type )
155+ submodel = model .models [0 ]
156+ p0 = submodel .posterior (test_x )
157+ self .assertTrue (torch .allclose (posterior .mean [:, [0 ]], p0 .mean ))
158+ self .assertTrue (torch .allclose (posterior .variance [:, [0 ]], p0 .variance ))
159+
160+ if gpytorch_posterior_expected :
161+ self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
162+ if outcome_transform != "None" :
127163 # ensure un-transformation is applied
128164 submodel = model .models [0 ]
129165 p0 = submodel .posterior (test_x )
@@ -136,8 +172,9 @@ def _base_test_ModelListGP(
136172
137173 # test output_indices
138174 posterior = model .posterior (test_x , output_indices = [0 ], observation_noise = True )
139- self .assertIsInstance (posterior , GPyTorchPosterior )
140- self .assertIsInstance (posterior .distribution , MultivariateNormal )
175+ self .assertIsInstance (posterior , expected_type )
176+ if gpytorch_posterior_expected :
177+ self .assertIsInstance (posterior .distribution , MultivariateNormal )
141178
142179 # test condition_on_observations
143180 f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
@@ -176,39 +213,50 @@ def _base_test_ModelListGP(
176213 X = torch .rand (3 , 1 , ** tkwargs )
177214 weights = torch .tensor ([1 , 2 ], ** tkwargs )
178215 post_tf = ScalarizedPosteriorTransform (weights = weights )
179- posterior_tf = model .posterior (X , posterior_transform = post_tf )
180- self .assertTrue (
181- torch .allclose (
182- posterior_tf .mean ,
183- model .posterior (X ).mean @ weights .unsqueeze (- 1 ),
216+ if gpytorch_posterior_expected :
217+ posterior_tf = model .posterior (X , posterior_transform = post_tf )
218+ self .assertTrue (
219+ torch .allclose (
220+ posterior_tf .mean ,
221+ model .posterior (X ).mean @ weights .unsqueeze (- 1 ),
222+ )
184223 )
185- )
186224
187225 return model
188226
189227 def test_ModelListGP (self ) -> None :
190- for dtype , use_octf in itertools .product (
191- (torch .float , torch .double ), (False , True )
228+ for dtype , outcome_transform in itertools .product (
229+ (torch .float , torch .double ), ("None" , "Standardize" , "Log" , "Chained" )
192230 ):
193231
194232 model = self ._base_test_ModelListGP (
195- fixed_noise = False , dtype = dtype , use_octf = use_octf
233+ fixed_noise = False , dtype = dtype , outcome_transform = outcome_transform
196234 )
197235 tkwargs = {"device" : self .device , "dtype" : dtype }
198236
199237 # test observation_noise
200238 test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
201239 posterior = model .posterior (test_x , observation_noise = True )
202- self .assertIsInstance (posterior , GPyTorchPosterior )
203- self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
240+
241+ gpytorch_posterior_expected = outcome_transform in ("None" , "Standardize" )
242+ expected_type = (
243+ GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
244+ )
245+ self .assertIsInstance (posterior , expected_type )
246+ if gpytorch_posterior_expected :
247+ self .assertIsInstance (
248+ posterior .distribution , MultitaskMultivariateNormal
249+ )
250+ else :
251+ self .assertIsInstance (posterior .posteriors [0 ], TransformedPosterior )
204252
205253 def test_ModelListGP_fixed_noise (self ) -> None :
206254
207- for dtype , use_octf in itertools .product (
208- (torch .float , torch .double ), (False , True )
255+ for dtype , outcome_transform in itertools .product (
256+ (torch .float , torch .double ), ("None" , "Standardize" )
209257 ):
210258 model = self ._base_test_ModelListGP (
211- fixed_noise = True , dtype = dtype , use_octf = use_octf
259+ fixed_noise = True , dtype = dtype , outcome_transform = outcome_transform
212260 )
213261 tkwargs = {"device" : self .device , "dtype" : dtype }
214262 f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
0 commit comments