 import pytensor.tensor as pt
 import xarray

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
     PointFunc,
     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable

+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )


 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
-    **minimize_kwargs,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.

     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.

     n_draws: int
         The number of draws to return from the variational approximation.

-    keep_untransformed : bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+    include_transformed : bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.

     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
-
-    minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.

-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".

-    use_hessp:
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
+
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.

-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.

     Returns
     -------
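# Illustrative usage sketch (not part of the original change). The model below and the
# import path are assumptions; only the keyword arguments mirror the signature above.
#
#     import numpy as np
#     import pymc as pm
#     from pymc_extras.inference import fit_dadvi
#
#     with pm.Model():
#         mu = pm.Normal("mu", 0.0, 1.0)
#         pm.Normal("y", mu, 1.0, observed=np.random.normal(size=100))
#         idata = fit_dadvi(n_fixed_draws=30, n_draws=1000, optimizer_method="trust-ncg")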
@@ -95,16 +104,25 @@ def fit_dadvi(

     References
     ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """

     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method
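# Note (added for clarity, assuming the signature above): a caller selects the two-level
# optimizer with, e.g.,
#     fit_dadvi(optimizer_method="basinhopping", minimizer_kwargs={"method": "L-BFGS-B"})
# where minimizer_kwargs is forwarded to the inner local minimizer; if no inner method is
# given, the block above defaults it to "L-BFGS-B".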

     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]

     var_params, objective = create_dadvi_graph(
         model,
@@ -113,31 +131,65 @@ def fit_dadvi(
         n_params=n_params,
     )

-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
     )

-    derivative_kwargs = {}
-
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )

-    result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
-        method=optimizer_method,
-        **derivative_kwargs,
-        **minimize_kwargs,
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
     )
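# Note (added for clarity): the two dictionaries above initialise every variational mean
# and every log standard deviation at zero, so the optimisation starts from a standard
# normal mean-field approximation over the unconstrained parameters.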

+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
@@ -148,9 +200,29 @@ def fit_dadvi(

     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )

-    return transformed_draws
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata


 def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
     return var_params, objective


-def transform_draws(
+def dadvi_result_to_idata(
     unstacked_draws: xarray.Dataset,
     model: Model,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
+    progressbar: bool = True,
 ):
     """
     Transforms the unconstrained draws back into the constrained space.
@@ -232,9 +305,12 @@ def transform_draws(
     n_draws: int
         The number of draws to return from the variational approximation.

-    keep_untransformed : bool
+    include_transformed : bool
         Whether or not to keep the unconstrained variables in the output.

+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@ def transform_draws(

     filtered_var_names = model.unobserved_value_vars
     vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
+        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
     )
     fn = pytensor.function(model.value_vars, vars_to_sample)
     point_func = PointFunc(fn)
@@ -256,6 +332,20 @@ def transform_draws(
         output_var_names=[x.name for x in vars_to_sample],
         coords=coords,
         dims=dims,
+        progressbar=progressbar,
     )

-    return transformed_result
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+    idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+    if unconstrained_names and include_transformed:
+        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+    return idata
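# Note (added for clarity): the InferenceData built above holds the constrained draws in
# "posterior" and, when include_transformed=True, the unconstrained draws in
# "unconstrained_posterior"; fit_dadvi then attaches the optimizer result and the model
# data via add_optimizer_result_to_inference_data and add_data_to_inference_data.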