diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py
index a92f4e9ee5..282afb2c3d 100644
--- a/botorch/models/fully_bayesian.py
+++ b/botorch/models/fully_bayesian.py
@@ -958,7 +958,10 @@ def _get_dummy_mcmc_samples(
         return mcmc_samples
 
     def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Custom logic for loading the state dict.
 
@@ -980,7 +983,7 @@ def load_state_dict(
         )
         self.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
 
 
 class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
@@ -1047,7 +1050,10 @@ def median_weight_variance(self) -> Tensor:
         return weight_variance.median(0).values.squeeze(0)
 
     def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Custom logic for loading the state dict.
 
@@ -1077,4 +1083,4 @@ def load_state_dict(
             mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
         self.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index 9f105efcfd..33052d8ee2 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -328,6 +328,7 @@ def load_state_dict(
         state_dict: Mapping[str, Any],
         strict: bool = True,
         keep_transforms: bool = True,
+        assign: bool = False,
     ) -> None:
         r"""Load the model state.
 
@@ -337,9 +338,14 @@ def load_state_dict(
             keep_transforms: A boolean indicating whether to keep the input and outcome
                 transforms. Doing so is useful when loading a model that was trained
                 on a full set of data, and is later loaded with a subset of the data.
+            assign: When ``False`` (default), the properties of the tensors in
+                the current module are preserved; when ``True``, the properties
+                of the tensors in the state dict are preserved. The only
+                exception is the ``requires_grad`` field of
+                :class:`~torch.nn.Parameter`, for which the module's value is kept.
         """
         if not keep_transforms:
-            super().load_state_dict(state_dict, strict)
+            super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
             return
 
         should_outcome_transform = (
@@ -368,10 +374,12 @@ def load_state_dict(
                 BotorchWarning,
                 stacklevel=3,
             )
-            super().load_state_dict(state_dict, strict)
+            super().load_state_dict(
+                state_dict=state_dict, strict=strict, assign=assign
+            )
             return
 
-        super().load_state_dict(state_dict, strict)
+        super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
 
         if getattr(self, "input_transform", None) is not None:
             self.input_transform.eval()
@@ -763,8 +771,11 @@ def load_state_dict(
         self,
         state_dict: Mapping[str, Any],
         strict: bool = True,
+        assign: bool = False,
     ) -> None:
-        return ModelList.load_state_dict(self, state_dict, strict)
+        return ModelList.load_state_dict(
+            self, state_dict=state_dict, strict=strict, assign=assign
+        )
 
     # pyre-fixme[14]: Inconsistent override in return types
     def posterior(
diff --git a/botorch/models/model.py b/botorch/models/model.py
index 8fb3b69eed..b0b3753b2e 100644
--- a/botorch/models/model.py
+++ b/botorch/models/model.py
@@ -581,6 +581,7 @@ def load_state_dict(
         state_dict: Mapping[str, Any],
         strict: bool = True,
         keep_transforms: bool = True,
+        assign: bool = False,
     ) -> None:
         """Initialize the fully Bayesian models before loading the state dict."""
         for i, m in enumerate(self.models):
@@ -589,7 +590,7 @@ def load_state_dict(
                 for k, v in state_dict.items()
                 if k.startswith(f"models.{i}.")
             }
-            m.load_state_dict(filtered_dict, strict=strict)
+            m.load_state_dict(filtered_dict, strict=strict, assign=assign)
 
     def fantasize(
         self,
diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py
index ff92d15ada..bdcb7ec9e6 100644
--- a/test/models/test_gpytorch.py
+++ b/test/models/test_gpytorch.py
@@ -1043,6 +1043,69 @@ def test_load_state_dict_with_transforms(self):
             )
         )
 
+    def test_load_state_dict_assign_parameter(self):
+        """Test that `assign` controls which tensor properties are preserved.
+
+        With assign=False (default), the current model's tensor properties are
+        kept; with assign=True, the state dict's tensor properties are kept.
+        """
+        # Create base model with double precision
+        tkwargs_double = {"device": self.device, "dtype": torch.double}
+        train_X_double = torch.rand(5, 2, **tkwargs_double)
+        train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True)
+
+        base_model = SingleTaskGP(
+            train_X=train_X_double,
+            train_Y=train_Y_double,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+        state_dict_double = base_model.state_dict()
+
+        # Create a new model with float32 precision (different dtype)
+        tkwargs_float = {"device": self.device, "dtype": torch.float}
+        train_X_float = torch.rand(5, 2, **tkwargs_float)
+        train_Y_float = torch.sin(train_X_float).sum(dim=1, keepdim=True)
+
+        # Test assign=False (default behavior)
+        model_assign_false = SingleTaskGP(
+            train_X=train_X_float,
+            train_Y=train_Y_float,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+
+        # Load double precision state dict with assign=False
+        model_assign_false.load_state_dict(
+            state_dict_double, keep_transforms=True, assign=False
+        )
+
+        # With assign=False, the model should keep its original float32 dtype
+        self.assertEqual(model_assign_false.train_inputs[0].dtype, torch.float)
+
+        # Test assign=True
+        model_assign_true = SingleTaskGP(
+            train_X=train_X_float,
+            train_Y=train_Y_float,
+            **_get_input_output_transform(d=2, indices=[0, 1], m=1),
+        )
+
+        # Load double precision state dict with assign=True
+        model_assign_true.load_state_dict(
+            state_dict_double, keep_transforms=True, assign=True
+        )
+
+        # With assign=True, the model should adopt the state dict's double dtype
+        self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double)
+        self.assertEqual(
+            model_assign_true.train_inputs[0].dtype,
+            state_dict_double["train_inputs.0"].dtype,
+        )
+
+        # Verify the two models have different dtypes
+        self.assertNotEqual(
+            model_assign_false.train_inputs[0].dtype,
+            model_assign_true.train_inputs[0].dtype,
+        )
+
     def test_load_state_dict_no_transforms(self):
         tkwargs = {"device": self.device, "dtype": torch.double}