Skip to content

Commit a668ccf

Browse files
tskarvoneToni Karvonenmmahsereci
authored
Multilevel Bayesian quadrature (#750)
Co-authored-by: Toni Karvonen <tskarvon@iki.fi> Co-authored-by: Maren Mahsereci <42842079+mmahsereci@users.noreply.github.com> Co-authored-by: Maren Mahsereci <maren.mhsrc@gmail.com>
1 parent a980df3 commit a668ccf

File tree

3 files changed

+289
-6
lines changed

3 files changed

+289
-6
lines changed

src/probnum/quad/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
"""
88

99
from . import integration_measures, kernel_embeddings, solvers
10-
from ._bayesquad import bayesquad, bayesquad_from_data
10+
from ._bayesquad import bayesquad, bayesquad_from_data, multilevel_bayesquad_from_data
1111

1212
# Public classes and functions. Order is reflected in documentation.
1313
__all__ = [
1414
"bayesquad",
1515
"bayesquad_from_data",
16+
"multilevel_bayesquad_from_data",
1617
]
1718

1819
# Set correct module paths. Corrects links and module paths in documentation.
1920
bayesquad.__module__ = "probnum.quad"
2021
bayesquad_from_data.__module__ = "probnum.quad"
22+
multilevel_bayesquad_from_data.__module__ = "probnum.quad"

src/probnum/quad/_bayesquad.py

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def bayesquad(
159159
References
160160
----------
161161
.. [1] Briol, F.-X., et al., Probabilistic integration: A role in statistical
162-
computation?, *Statistical Science 34.1*, 2019, 1-22, 2019
162+
computation?, *Statistical Science 34.1*, 2019, 1-22.
163163
.. [2] Rasmussen, C. E., and Z. Ghahramani, Bayesian Monte Carlo, *Advances in
164164
Neural Information Processing Systems*, 2003, 505-512.
165165
.. [3] Mckay et al., A Comparison of Three Methods for Selecting Values of Input
@@ -168,7 +168,6 @@ def bayesquad(
168168
Examples
169169
--------
170170
>>> import numpy as np
171-
172171
>>> input_dim = 1
173172
>>> domain = (0, 1)
174173
>>> def fun(x):
@@ -299,12 +298,150 @@ def bayesquad_from_data(
299298
return integral_belief, info
300299

301300

301+
def multilevel_bayesquad_from_data(
302+
nodes: Tuple[np.ndarray, ...],
303+
fun_diff_evals: Tuple[np.ndarray, ...],
304+
kernels: Optional[Tuple[Kernel, ...]] = None,
305+
measure: Optional[IntegrationMeasure] = None,
306+
domain: Optional[DomainLike] = None,
307+
options: Optional[dict] = None,
308+
) -> Tuple[Normal, Tuple[BQIterInfo, ...]]:
309+
r"""Infer the value of an integral from given sets of nodes and function
310+
evaluations using a multilevel method.
311+
312+
In multilevel Bayesian quadrature, the integral :math:`\int_\Omega f(x) d \mu(x)`
313+
is (approximately) decomposed as a telescoping sum over :math:`L+1` levels:
314+
315+
.. math:: \int_\Omega f(x) d \mu(x) \approx \int_\Omega f_0(x) d
316+
\mu(x) + \sum_{l=1}^L \int_\Omega [f_l(x) - f_{l-1}(x)] d \mu(x),
317+
318+
where :math:`f_l` is an increasingly accurate but also increasingly expensive
319+
approximation to :math:`f`. It is not necessary that the highest level approximation
320+
:math:`f_L` be equal to :math:`f`.
321+
322+
Bayesian quadrature is subsequently applied to independently infer each of the
323+
:math:`L+1` integrals and the outputs are summed to infer
324+
:math:`\int_\Omega f(x) d \mu(x)`. [1]_
325+
326+
Parameters
327+
----------
328+
nodes
329+
Tuple of length :math:`L+1` containing the locations for each level at which
330+
the functionn evaluations are available as ``fun_diff_evals``. Each element
331+
must be a shape=(n_eval, input_dim) ``np.ndarray``. If a tuple containing only
332+
one element is provided, it is inferred that the same nodes ``nodes[0]`` are
333+
used on every level.
334+
fun_diff_evals
335+
Tuple of length :math:`L+1` containing the evaluations of :math:`f_l - f_{l-1}`
336+
for each level at the nodes provided in ``nodes``. Each element must be a
337+
shape=(n_eval,) ``np.ndarray``. The zeroth element contains the evaluations of
338+
:math:`f_0`.
339+
kernels
340+
Tuple of length :math:`L+1` containing the kernels used for the GP model at each
341+
level. See **Notes** for further details. Defaults to the ``ExpQuad`` kernel for
342+
each level.
343+
measure
344+
The integration measure. Defaults to the Lebesgue measure.
345+
domain
346+
The integration domain. Contains lower and upper bound as scalar or
347+
``np.ndarray``. Obsolete if ``measure`` is given.
348+
options
349+
A dictionary with the following optional solver settings
350+
351+
scale_estimation : Optional[str]
352+
Estimation method to use to compute the scale parameter. Used
353+
independently on each level. Defaults to 'mle'. Options are
354+
355+
============================== =======
356+
Maximum likelihood estimation ``mle``
357+
============================== =======
358+
359+
jitter : Optional[FloatLike]
360+
Non-negative jitter to numerically stabilise kernel matrix
361+
inversion. Same jitter is used on each level. Defaults to 1e-8.
362+
363+
Returns
364+
-------
365+
integral :
366+
The integral belief subject to the provided measure or domain.
367+
infos :
368+
Information on the performance of the method for each level.
369+
370+
Raises
371+
------
372+
ValueError
373+
If ``nodes``, ``fun_diff_evals`` or ``kernels`` have different lengths.
374+
375+
Warns
376+
-----
377+
UserWarning
378+
When ``domain`` is given but not used.
379+
380+
Notes
381+
-----
382+
The tuple of kernels provided by the ``kernels`` parameter must contain distinct
383+
kernel instances, i.e., ``kernels[i] is kernel[j]`` must return ``False`` for any
384+
:math:`i\neq j`.
385+
386+
References
387+
----------
388+
.. [1] Li, K., et al., Multilevel Bayesian quadrature, AISTATS, 2023.
389+
390+
Examples
391+
--------
392+
>>> import numpy as np
393+
>>> input_dim = 1
394+
>>> domain = (0, 1)
395+
>>> n_level = 6
396+
>>> def fun(x, l):
397+
... return x.reshape(-1, ) / (l + 1.0)
398+
>>> nodes = ()
399+
>>> fun_diff_evals = ()
400+
>>> for l in range(n_level):
401+
... n_l = 2*l + 1
402+
... nodes += (np.reshape(np.linspace(0, 1, n_l), (n_l, input_dim)),)
403+
... fun_diff_evals += (np.reshape(fun(nodes[l], l), (n_l,)),)
404+
>>> F, infos = multilevel_bayesquad_from_data(nodes=nodes,
405+
... fun_diff_evals=fun_diff_evals,
406+
... domain=domain)
407+
>>> print(np.round(F.mean, 4))
408+
0.7252
409+
"""
410+
411+
n_level = len(fun_diff_evals)
412+
if kernels is None:
413+
kernels = n_level * (None,)
414+
if len(nodes) == 1:
415+
nodes = n_level * (nodes[0],)
416+
if not len(nodes) == len(fun_diff_evals) == len(kernels):
417+
raise ValueError(
418+
f"You must provide an equal number of kernels ({(len(kernels))}), "
419+
f"vectors of function evaluations ({len(fun_diff_evals)}) "
420+
f"and sets of nodes ({len(nodes)})."
421+
)
422+
423+
integer_belief = Normal(mean=0.0, cov=0.0)
424+
infos = ()
425+
for level in range(n_level):
426+
integer_belief_l, info_l = bayesquad_from_data(
427+
nodes=nodes[level],
428+
fun_evals=fun_diff_evals[level],
429+
kernel=kernels[level],
430+
measure=measure,
431+
domain=domain,
432+
options=options,
433+
)
434+
integer_belief += integer_belief_l
435+
infos += (info_l,)
436+
437+
return integer_belief, infos
438+
439+
302440
def _check_domain_measure_compatibility(
303441
input_dim: IntLike,
304442
domain: Optional[DomainLike],
305443
measure: Optional[IntegrationMeasure],
306444
) -> Tuple[int, Optional[DomainType], IntegrationMeasure]:
307-
308445
input_dim = int(input_dim)
309446

310447
# Neither domain nor measure given

tests/test_quad/test_bayesquad/test_bq.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Test cases for Bayesian quadrature."""
2+
import copy
23

34
import numpy as np
45
import pytest
56
from scipy.integrate import quad as scipyquad
67

7-
from probnum.quad import bayesquad, bayesquad_from_data
8-
from probnum.quad.integration_measures import LebesgueMeasure
8+
from probnum.quad import bayesquad, bayesquad_from_data, multilevel_bayesquad_from_data
9+
from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure
910
from probnum.quad.kernel_embeddings import KernelEmbedding
1011
from probnum.randvars import Normal
1112

@@ -219,3 +220,146 @@ def test_zero_function_gives_zero_variance_with_mle(rng):
219220
)
220221
assert bq_integral1.var == 0.0
221222
assert bq_integral2.var == 0.0
223+
224+
225+
def test_multilevel_bayesquad_from_data_output_types_and_shapes(kernel, measure, rng):
226+
"""Test correct output for different inputs to multilevel BQ."""
227+
228+
# full set of nodes
229+
ns_1 = (3, 7, 2)
230+
n_level_1 = len(ns_1)
231+
fun_diff_evals_1 = tuple(np.zeros(ns_1[l]) for l in range(n_level_1))
232+
nodes_full = tuple(measure.sample((ns_1[l]), rng=rng) for l in range(n_level_1))
233+
234+
# i) default kernel
235+
F, infos = multilevel_bayesquad_from_data(
236+
nodes=nodes_full,
237+
fun_diff_evals=fun_diff_evals_1,
238+
measure=measure,
239+
)
240+
assert isinstance(F, Normal)
241+
assert len(infos) == n_level_1
242+
243+
# ii) full kernel list
244+
kernels_full_1 = tuple(copy.deepcopy(kernel) for _ in range(n_level_1))
245+
F, infos = multilevel_bayesquad_from_data(
246+
nodes=nodes_full,
247+
fun_diff_evals=fun_diff_evals_1,
248+
kernels=kernels_full_1,
249+
measure=measure,
250+
)
251+
assert isinstance(F, Normal)
252+
assert len(infos) == n_level_1
253+
254+
# one set of nodes
255+
n_level_2 = 3
256+
ns_2 = n_level_2 * (7,)
257+
fun_diff_evals_2 = tuple(np.zeros(ns_2[l]) for l in range(n_level_2))
258+
nodes_1 = (measure.sample(n_sample=ns_2[0], rng=rng),)
259+
260+
# i) default kernel
261+
F, infos = multilevel_bayesquad_from_data(
262+
nodes=nodes_1,
263+
fun_diff_evals=fun_diff_evals_2,
264+
measure=measure,
265+
)
266+
assert isinstance(F, Normal)
267+
assert len(infos) == n_level_2
268+
269+
# ii) full kernel list
270+
kernels_full_2 = tuple(copy.deepcopy(kernel) for _ in range(n_level_2))
271+
F, infos = multilevel_bayesquad_from_data(
272+
nodes=nodes_1,
273+
fun_diff_evals=fun_diff_evals_2,
274+
kernels=kernels_full_2,
275+
measure=measure,
276+
)
277+
assert isinstance(F, Normal)
278+
assert len(infos) == n_level_2
279+
280+
281+
def test_multilevel_bayesquad_from_data_wrong_inputs(kernel, measure, rng):
282+
"""Tests that wrong number inputs to multilevel BQ throw errors."""
283+
ns = (3, 7, 11)
284+
n_level = len(ns)
285+
fun_diff_evals = tuple(np.zeros(ns[l]) for l in range(n_level))
286+
287+
# number of nodes does not match the number of fun evals
288+
wrong_n_nodes = 2
289+
nodes_2 = tuple(measure.sample((ns[l]), rng=rng) for l in range(wrong_n_nodes))
290+
with pytest.raises(ValueError):
291+
multilevel_bayesquad_from_data(
292+
nodes=nodes_2,
293+
fun_diff_evals=fun_diff_evals,
294+
measure=measure,
295+
)
296+
297+
# number of kernels does not match number of fun evals
298+
wrong_n_kernels = 2
299+
kernels = tuple(copy.deepcopy(kernel) for _ in range(wrong_n_kernels))
300+
nodes_1 = (measure.sample(n_sample=ns[0], rng=rng),)
301+
with pytest.raises(ValueError):
302+
multilevel_bayesquad_from_data(
303+
nodes=nodes_1,
304+
fun_diff_evals=fun_diff_evals,
305+
kernels=kernels,
306+
measure=measure,
307+
)
308+
309+
310+
def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_1d():
311+
"""Test that multilevel BQ equals BQ when all but one level are given non-zero
312+
function evaluations for 1D data."""
313+
n_level = 5
314+
domain = (0, 3.3)
315+
nodes = tuple(np.linspace(0, 1, 2 * l + 1)[:, None] for l in range(n_level))
316+
for i in range(n_level):
317+
jitter = 1e-5 * (i + 1.0)
318+
fun_evals = nodes[i][:, 0] ** (2 + 0.3 * i) + 1.2
319+
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
320+
fun_diff_evals[i] = fun_evals
321+
mlbq_integral, _ = multilevel_bayesquad_from_data(
322+
nodes=nodes,
323+
fun_diff_evals=tuple(fun_diff_evals),
324+
domain=domain,
325+
options=dict(jitter=jitter),
326+
)
327+
bq_integral, _ = bayesquad_from_data(
328+
nodes=nodes[i],
329+
fun_evals=fun_evals,
330+
domain=domain,
331+
options=dict(jitter=jitter),
332+
)
333+
assert mlbq_integral.mean == bq_integral.mean
334+
assert mlbq_integral.cov == bq_integral.cov
335+
336+
337+
def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_2d():
338+
"""Test that multilevel BQ equals BQ when all but one level are given non-zero
339+
function evaluations for 2D data."""
340+
input_dim = 2
341+
n_level = 5
342+
measure = GaussianMeasure(np.full((input_dim,), 0.2), cov=0.6 * np.eye(input_dim))
343+
_gh = gauss_hermite_tensor
344+
nodes = tuple(
345+
_gh(l + 1, input_dim, measure.mean, measure.cov)[0] for l in range(n_level)
346+
)
347+
for i in range(n_level):
348+
jitter = 1e-5 * (i + 1.0)
349+
fun_evals = np.sin(nodes[i][:, 0] * i) + (i + 1.0) * np.cos(nodes[i][:, 1])
350+
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
351+
fun_diff_evals[i] = fun_evals
352+
mlbq_integral, _ = multilevel_bayesquad_from_data(
353+
nodes=nodes,
354+
fun_diff_evals=tuple(fun_diff_evals),
355+
measure=measure,
356+
options=dict(jitter=jitter),
357+
)
358+
bq_integral, _ = bayesquad_from_data(
359+
nodes=nodes[i],
360+
fun_evals=fun_evals,
361+
measure=measure,
362+
options=dict(jitter=jitter),
363+
)
364+
assert mlbq_integral.mean == bq_integral.mean
365+
assert mlbq_integral.cov == bq_integral.cov

0 commit comments

Comments
 (0)