Skip to content

Commit 3805046

Browse files
ItsMrLinfacebook-github-bot
authored andcommitted
Fix FixedSingleSampleModel dtype/device conversion (#1254)
Summary: Pull Request resolved: #1254 Sometimes `model.train_inputs[0]` is a tuple instead of a tensor. Instead of assuming the structure of the model's class members, will just cast X on the fly in `forward`. It shouldn't cause any additional runtime if device and dtype align. Reviewed By: Balandat Differential Revision: D36976244 fbshipit-source-id: 6bd670e52d155ee5a6215598ce840ec1d6c7ae8c
1 parent 4224cd8 commit 3805046

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

botorch/models/deterministic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,7 @@ def __init__(self, model: Model, w: Optional[Tensor] = None) -> None:
215215
self.model = model
216216
self._num_outputs = model.num_outputs
217217
self.w = torch.randn(model.num_outputs)
218-
# Check since Model doesn't guarantee a train_inputs attribute
219-
if hasattr(model, "train_inputs"):
220-
self.w = self.w.to(model.train_inputs[0])
221218

222219
def forward(self, X: Tensor) -> Tensor:
223220
post = self.model.posterior(X)
224-
return post.mean + post.variance.sqrt() * self.w
221+
return post.mean + post.variance.sqrt() * self.w.to(X)

test/models/test_deterministic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,7 @@ def test_FixedSingleSampleModel(self):
168168
train_Y_double = torch.rand(2, 2, dtype=torch.double)
169169
model_double = SingleTaskGP(train_X=train_X_double, train_Y=train_Y_double)
170170
fss_model_double = FixedSingleSampleModel(model=model_double)
171-
self.assertTrue(fss_model_double.w.dtype == train_X_double.dtype)
171+
test_X_float = torch.rand(2, 3, dtype=torch.float)
172+
173+
# the following line should execute fine
174+
fss_model_double.posterior(test_X_float)

0 commit comments

Comments
 (0)