Skip to content

Commit 3f89c29

Browse files
esantorella and facebook-github-bot
esantorella authored and facebook-github-bot committed
Raise an error if Standardize outcome transform's untransform_posterior is used without first calling the transform on outcomes (#1569)
Summary: Pull Request resolved: #1569

Say someone does the following bad thing:

> tf = Standardize(m=1)
> tf.untransform_posterior(posterior)

Old behavior:
- Means and standard deviations are initialized in `Standardize.__init__` with a tensor of zeros, with `device` not set.
- With a posterior on the CPU, the posterior would be nonsensically untransformed with means and standard deviations of zero.
- With a posterior on the GPU, this would cause an error about tensors being on different devices, e.g. https://www.internalfb.com/diff/D42019721?dst_version_fbid=1618282175279712&selected_signal=dGVzdF9pZDo1NjI5NTAwMjcwNTY2NTk%3D&selected_signal_verification_phase=1

New behavior:
- Means and standard deviations are initialized as None.
- An informative error is raised.

Reviewed By: saitcakmak, Balandat
Differential Revision: D42039100
fbshipit-source-id: f9585e2c32715216781651ebfcd8878d2e7e6971
1 parent 76062a6 commit 3f89c29

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

botorch/models/transforms/outcome.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,9 @@ def __init__(
223223
standardization (if lower, only de-mean the data).
224224
"""
225225
super().__init__()
226-
self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
227-
self.register_buffer("stdvs", torch.zeros(*batch_shape, 1, m))
228-
self.register_buffer("_stdvs_sq", torch.zeros(*batch_shape, 1, m))
226+
self.register_buffer("means", None)
227+
self.register_buffer("stdvs", None)
228+
self.register_buffer("_stdvs_sq", None)
229229
self._outputs = normalize_indices(outputs, d=m)
230230
self._m = m
231231
self._batch_shape = batch_shape
@@ -296,9 +296,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform:
296296
batch_shape=self._batch_shape,
297297
min_stdv=self._min_stdv,
298298
)
299-
new_tf.means = self.means[..., nlzd_idcs]
300-
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
301-
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
299+
if self.means is not None:
300+
new_tf.means = self.means[..., nlzd_idcs]
301+
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
302+
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
302303
if not self.training:
303304
new_tf.eval()
304305
return new_tf
@@ -319,6 +320,13 @@ def untransform(
319320
- The un-standardized outcome observations.
320321
- The un-standardized observation noise (if applicable).
321322
"""
323+
if self.means is None:
324+
raise RuntimeError(
325+
"`Standardize` transforms must be called on outcome data "
326+
"(e.g. `transform(Y)`) before calling `untransform`, since "
327+
"means and standard deviations need to be computed."
328+
)
329+
322330
Y_utf = self.means + self.stdvs * Y
323331
Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
324332
return Y_utf, Yvar_utf
@@ -341,13 +349,20 @@ def untransform_posterior(
341349
"Standardize does not yet support output selection for "
342350
"untransform_posterior"
343351
)
352+
if self.means is None:
353+
raise RuntimeError(
354+
"`Standardize` transforms must be called on outcome data "
355+
"(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
356+
"means and standard deviations need to be computed."
357+
)
344358
is_mtgp_posterior = False
345359
if type(posterior) is GPyTorchPosterior:
346360
is_mtgp_posterior = posterior._is_mt
347361
if not self._m == posterior._extended_shape()[-1] and not is_mtgp_posterior:
348362
raise RuntimeError(
349-
"Incompatible output dimensions encountered for transform "
350-
f"{self._m} and posterior {posterior._extended_shape()[-1]}."
363+
"Incompatible output dimensions encountered. Transform has output "
364+
f"dimension {self._m} and posterior has "
365+
f"{posterior._extended_shape()[-1]}."
351366
)
352367

353368
if type(posterior) is not GPyTorchPosterior:

test/models/transforms/test_outcome.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,36 @@ def test_abstract_base_outcome_transform(self):
6363
with self.assertRaises(NotImplementedError):
6464
oct.untransform_posterior(None)
6565

66+
def test_standardize_raises_when_mean_not_set(self) -> None:
    """Check that an un-fitted `Standardize` refuses to untransform.

    A `Standardize` transform that has never been called on outcome data
    has `means`/`stdvs` set to `None`. Both `untransform_posterior` and
    `untransform` must then raise a `RuntimeError` with an informative
    message, whether the transform is used directly or nested inside
    `ChainedOutcomeTransform`s. `subset_output` must still work on an
    un-fitted transform and return the same transform type.
    """
    # Local import: the file's import block is not visible from this block.
    import re

    def _not_fitted_message(method_name: str) -> str:
        # Exact error text expected from `Standardize` for `method_name`.
        return (
            "`Standardize` transforms must be called on outcome data "
            f"(e.g. `transform(Y)`) before calling `{method_name}`, since "
            "means and standard deviations need to be computed."
        )

    posterior = _get_test_posterior(
        shape=torch.Size([1, 1]), device=self.device, dtype=torch.float64
    )
    y = torch.arange(3, device=self.device, dtype=torch.float64)

    for transform in [
        Standardize(m=1),
        ChainedOutcomeTransform(
            chained=ChainedOutcomeTransform(stand=Standardize(m=1))
        ),
    ]:
        # NOTE: `assertRaises(..., msg=...)` does not check the exception
        # text (`msg` is only the failure message), so match the message
        # explicitly. `re.escape` is needed because the text contains
        # regex metacharacters such as parentheses and periods.
        with self.assertRaisesRegex(
            RuntimeError, re.escape(_not_fitted_message("untransform_posterior"))
        ):
            transform.untransform_posterior(posterior)

        # Subsetting an un-fitted transform must not fail.
        new_tf = transform.subset_output([0])
        self.assertIsInstance(new_tf, type(transform))

        with self.assertRaisesRegex(
            RuntimeError, re.escape(_not_fitted_message("untransform"))
        ):
            transform.untransform(y)
95+
6696
def test_standardize(self):
6797
# test error on incompatible dim
6898
tf = Standardize(m=1)
@@ -208,8 +238,15 @@ def test_standardize(self):
208238

209239
# test error on incompatible output dimension
210240
# TODO: add a unit test for MTGP posterior once #840 goes in
211-
tf_big = Standardize(m=4).eval()
212-
with self.assertRaises(RuntimeError):
241+
tf_big = Standardize(m=4)
242+
Y = torch.arange(4, device=self.device, dtype=dtype).reshape((1, 4))
243+
tf_big(Y)
244+
with self.assertRaises(
245+
RuntimeError,
246+
msg="Incompatible output dimensions encountered. Transform has output "
247+
f"dimension {tf._m} and posterior has "
248+
f"{posterior._extended_shape()[-1]}.",
249+
):
213250
tf_big.untransform_posterior(posterior2)
214251

215252
# test transforming a subset of outcomes

0 commit comments

Comments
 (0)