From 1a678146886ae7262eb9ea1bae76e91f70b8e01f Mon Sep 17 00:00:00 2001
From: Carl Hvarfner
Date: Mon, 17 Nov 2025 11:11:04 -0800
Subject: [PATCH] Structuring arguments in gen_candidates_torch (#3019)

Summary:
X-link: https://github.com/pytorch/botorch/pull/3019

Structuring the optimizer and stopping_criterion arguments in
gen_candidates_torch: options["optimizer_options"] is now forwarded to the
torch optimizer (e.g. lr, weight_decay), and
options["stopping_criterion_options"] is forwarded to ExpMAStoppingCriterion
(e.g. maxiter). Passing maxiter at the top level of options remains supported
for backward compatibility but emits a DeprecationWarning.

Fixes https://github.com/pytorch/botorch/issues/2994.

Reviewed By: sdaulton

Differential Revision: D82839737
---
 botorch/generation/gen.py   | 29 +++++++++++++++++++---
 test/generation/test_gen.py | 49 +++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py
index 7611691d27..a58232aa3e 100644
--- a/botorch/generation/gen.py
+++ b/botorch/generation/gen.py
@@ -532,7 +532,10 @@ def gen_candidates_torch(
         optimizer (Optimizer): The pytorch optimizer to use to perform
             candidate search.
         options: Options used to control the optimization. Includes
-            maxiter: Maximum number of iterations
+
+            - optimizer_options: Dict of additional options to pass to the optimizer
+                (e.g. lr, weight_decay)
+            - stopping_criterion_options: Dict of options for the stopping criterion.
         callback: A callback function accepting the current iteration, loss,
             and gradients as arguments. This function is executed after
             computing the loss and gradients, but before calling the optimizer.
@@ -571,7 +574,6 @@
     # the 1st order optimizers implemented in this method.
     # Here, it does not matter whether one combines multiple optimizations into
     # one or not.
-    options.pop("max_optimization_problem_aggregation_size", None)
     _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
     clamped_candidates = _clamp(initial_conditions)
     if fixed_features:
@@ -580,11 +582,30 @@
             ...,
             [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
         ]
     clamped_candidates = clamped_candidates.requires_grad_(True)
-    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
+    # Extract optimizer and stopping criterion options from the options dict
+    optimizer_options = options.get("optimizer_options", {}).copy()
+    stopping_criterion_options = options.get("stopping_criterion_options", {}).copy()
+
+    # Backward compatibility: if the old 'maxiter' parameter is passed, move it to
+    # stopping_criterion_options with a deprecation warning
+    if "maxiter" in options:
+        warnings.warn(
+            "Passing 'maxiter' directly in options is deprecated. "
+            "Please use options['stopping_criterion_options']['maxiter'] instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # For backward compatibility, pass to stopping_criterion_options
+        if "maxiter" not in stopping_criterion_options:
+            stopping_criterion_options["maxiter"] = options["maxiter"]
+
+    optimizer_options.setdefault("lr", 0.025)
+    _optimizer = optimizer(params=[clamped_candidates], **optimizer_options)
     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(**options)
+    stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
     while not stop:
         i += 1
         with torch.no_grad():
diff --git a/test/generation/test_gen.py b/test/generation/test_gen.py
index dc5961038c..b847934942 100644
--- a/test/generation/test_gen.py
+++ b/test/generation/test_gen.py
@@ -324,6 +324,55 @@ def test_gen_candidates_torch_timeout_behavior(self):
         self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
         self.assertTrue("Optimization timed out" in logs.output[-1])
 
+    def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
+        """Test that optimizer is created with correct args."""
+        self._setUp(double=False)
+        qEI = qExpectedImprovement(self.model, best_f=self.f_best)
+
+        # Test new structured API
+        with self.subTest(api="structured"):
+            # Create a mock optimizer class
+            mock_optimizer_class = mock.MagicMock()
+            mock_optimizer_instance = mock.MagicMock()
+            mock_optimizer_class.return_value = mock_optimizer_instance
+
+            gen_candidates_torch(
+                initial_conditions=self.initial_conditions,
+                acquisition_function=qEI,
+                lower_bounds=0,
+                upper_bounds=1,
+                optimizer=mock_optimizer_class,
+                options={
+                    "optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
+                    "stopping_criterion_options": {"maxiter": 1},
+                },
+            )
+
+            # Verify that the optimizer was called with the correct arguments
+            mock_optimizer_class.assert_called_once()
+            call_args = mock_optimizer_class.call_args
+            self.assertIn("params", call_args.kwargs)
+            self.assertEqual(call_args.kwargs["lr"], 0.02)
+            self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)
+
+        # Test backward compatibility with old maxiter parameter
+        with self.subTest(api="backward_compat"):
+            with warnings.catch_warnings(record=True) as ws:
+                warnings.simplefilter("always", category=DeprecationWarning)
+                gen_candidates_torch(
+                    initial_conditions=self.initial_conditions,
+                    acquisition_function=qEI,
+                    lower_bounds=0,
+                    upper_bounds=1,
+                    options={"maxiter": 1},
+                )
+            # Verify deprecation warning was raised
+            deprecation_warnings = [
+                w for w in ws if issubclass(w.category, DeprecationWarning)
+            ]
+            self.assertTrue(len(deprecation_warnings) > 0)
+            self.assertIn("maxiter", str(deprecation_warnings[0].message))
+
     def test_gen_candidates_scipy_warns_opt_no_res(self):
         ckwargs = {"dtype": torch.float, "device": self.device}