@@ -9,7 +9,7 @@
 import pytensor
 import pytensor.tensor as pt
 
-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(
 
 
 def find_MAP(
-    method: minimize_method,
+    method: minimize_method | Literal["basinhopping"],
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -352,14 +352,17 @@ def find_MAP(
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
     """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
 
     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -387,7 +390,9 @@ def find_MAP(
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
 
     Returns
     -------
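
For reference, a hedged usage sketch of the API documented above. The toy model is hypothetical, `find_MAP` is assumed to be importable from the module this diff touches, and `niter` is a standard `scipy.optimize.basinhopping` keyword forwarded through ``**optimizer_kwargs``:

```python
import pymc as pm

# Hypothetical toy model; any PyMC model would work the same way.
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

# Local optimization: `method` is forwarded to scipy.optimize.minimize.
map_local = find_MAP(method="L-BFGS-B", model=model)

# Global optimization: outer kwargs go to scipy.optimize.basinhopping,
# while minimizer_kwargs configures the inner local optimizer.
map_global = find_MAP(
    method="basinhopping",
    model=model,
    niter=25,
    minimizer_kwargs={"method": "powell"},
)
```
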
@@ -413,6 +418,18 @@ def find_MAP(
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
         method, use_grad, use_hess, use_hessp
     )
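
A minimal trace of the inner-method defaulting added in this hunk, under assumed user inputs:

```python
# Assumed user input, mirroring the hunk above.
optimizer_kwargs = {"niter": 10, "minimizer_kwargs": {"method": "powell"}}
method = "basinhopping"

do_basinhopping = method == "basinhopping"
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})

if do_basinhopping:
    # The inner optimizer falls back to L-BFGS-B only when the user
    # did not choose one themselves.
    method = minimizer_kwargs.pop("method", "L-BFGS-B")
    minimizer_kwargs["method"] = method

print(method)            # powell  (user-supplied inner method wins)
print(optimizer_kwargs)  # {'niter': 10}  (left over for basinhopping itself)
```
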
@@ -431,17 +448,37 @@ def find_MAP(
     args = optimizer_kwargs.pop("args", None)
 
     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why it is not set here, regardless of user settings.
-    optimizer_result = minimize(
-        f=f_logp,
-        x0=cast(np.ndarray[float], initial_params.data),
-        args=args,
-        hess=f_hess,
-        hessp=f_hessp,
-        progressbar=progressbar,
-        method=method,
-        **optimizer_kwargs,
-    )
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hess" not in minimizer_kwargs:
+            minimizer_kwargs["hess"] = f_hess
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hess=f_hess,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )
 
     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
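
For grounding, a self-contained sketch of the plain-scipy pattern the new branch wraps; the toy objective stands in for the compiled negative log-posterior:

```python
import numpy as np
from scipy.optimize import basinhopping, minimize

def loss(x):
    # Toy multimodal objective: many local minima, one global minimum.
    return np.cos(14.5 * x[0] - 0.3) + (x[0] + 0.2) * x[0]

x0 = np.array([1.0])

# Local-only optimization, as in the `else` branch above.
local = minimize(loss, x0, method="L-BFGS-B")

# Global optimization: the inner optimizer is configured through
# minimizer_kwargs, matching the new basinhopping branch.
result = basinhopping(loss, x0, niter=50, minimizer_kwargs={"method": "L-BFGS-B"})

print(local.x, result.x)
```
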