Commit 53a4302

Introducing quad acquisition functions (#749)
1 parent 90b5dd3 commit 53a4302

18 files changed: +495 −41 lines changed

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+Acquisition Functions
+---------------------
+.. automodapi:: probnum.quad.solvers.acquisition_functions
+    :no-heading:
+    :headings: "*"

docs/source/api/quad/solvers.rst

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,11 @@ probnum.quad.solvers

     solvers.policies

+.. toctree::
+    :hidden:
+
+    solvers.acquisition_functions
+
 .. toctree::
     :hidden:

src/probnum/quad/_bayesquad.py

Lines changed: 9 additions & 4 deletions
@@ -6,6 +6,7 @@
 methods return a random variable, specifying the belief about the true value of the
 integral.
 """
+
 from __future__ import annotations

 from typing import Callable, Optional, Tuple
@@ -70,10 +71,11 @@ def bayesquad(
     policy
         Type of acquisition strategy to use. Defaults to 'bmc'. Options are

-        ========================== =======
-        Bayesian Monte Carlo [2]_  ``bmc``
-        van Der Corput points      ``vdc``
-        ========================== =======
+        ============================================ ===========
+        Bayesian Monte Carlo [2]_                    ``bmc``
+        Van Der Corput points                        ``vdc``
+        Uncertainty Sampling with random candidates  ``us_rand``
+        ============================================ ===========

     initial_design
         The type of initial design to use. If ``None`` is given, no initial design is
@@ -112,6 +114,9 @@ def bayesquad(
     num_initial_design_nodes : Optional[IntLike]
         The number of nodes created by the initial design. Defaults to
         ``input_dim * 5`` if an initial design is given.
+    us_rand_num_candidates : Optional[IntLike]
+        The number of candidate nodes used by the policy 'us_rand'. Defaults
+        to 1e2.

     Returns
     -------
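
A minimal usage sketch of the new 'us_rand' option (not part of the commit). It assumes the bayesquad signature documented above, with the policy-specific setting forwarded through an options dict and a NumPy generator supplied for the stochastic policy; the integrand, domain, and option value are illustrative only:

    import numpy as np
    from probnum.quad import bayesquad

    # Integrate exp(-x^2) over [0, 1] with uncertainty sampling on random candidates.
    F, info = bayesquad(
        fun=lambda x: np.exp(-x[:, 0] ** 2),  # vectorized integrand, returns shape (n,)
        input_dim=1,
        domain=(0.0, 1.0),
        policy="us_rand",
        rng=np.random.default_rng(0),
        options={"us_rand_num_candidates": 50},
    )
    print(F.mean, F.std)  # Gaussian belief over the integral value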

src/probnum/quad/solvers/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -1,6 +1,12 @@
 """Bayesian quadrature methods and their components."""

-from . import belief_updates, initial_designs, policies, stopping_criteria
+from . import (
+    acquisition_functions,
+    belief_updates,
+    initial_designs,
+    policies,
+    stopping_criteria,
+)
 from ._bayesian_quadrature import BayesianQuadrature
 from ._bq_state import BQIterInfo, BQState

src/probnum/quad/solvers/_bayesian_quadrature.py

Lines changed: 18 additions & 2 deletions
@@ -10,9 +10,15 @@
 from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure
 from probnum.quad.kernel_embeddings import KernelEmbedding
 from probnum.quad.solvers._bq_state import BQIterInfo, BQState
+from probnum.quad.solvers.acquisition_functions import WeightedPredictiveVariance
 from probnum.quad.solvers.belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate
 from probnum.quad.solvers.initial_designs import InitialDesign, LatinDesign, MCDesign
-from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy
+from probnum.quad.solvers.policies import (
+    Policy,
+    RandomMaxAcquisitionPolicy,
+    RandomPolicy,
+    VanDerCorputPolicy,
+)
 from probnum.quad.solvers.stopping_criteria import (
     BQStoppingCriterion,
     ImmediateStop,
@@ -38,7 +44,7 @@ class BayesianQuadrature:
     Parameters
     ----------
     kernel
-        The kernel used for the GP model.
+        The kernel used for the Gaussian process model.
     measure
         The integration measure.
     policy
@@ -139,6 +145,9 @@ def from_problem(
         num_initial_design_nodes : Optional[IntLike]
             The number of nodes created by the initial design. Defaults to
             ``input_dim * 5`` if an initial design is given.
+        us_rand_num_candidates : Optional[IntLike]
+            The number of candidate nodes used by the policy 'us_rand'. Defaults
+            to 1e2.

         Returns
         -------
@@ -175,6 +184,7 @@ def from_problem(
         num_initial_design_nodes = options.get(
             "num_initial_design_nodes", int(5 * input_dim)
         )
+        us_rand_num_candidates = options.get("us_rand_num_candidates", int(1e2))

         # Set up integration measure
         if domain is None and measure is None:
@@ -198,6 +208,12 @@ def from_problem(
             policy = RandomPolicy(batch_size, measure.sample)
         elif policy == "vdc":
             policy = VanDerCorputPolicy(batch_size, measure)
+        elif policy == "us_rand":
+            policy = RandomMaxAcquisitionPolicy(
+                batch_size=1,
+                acquisition_func=WeightedPredictiveVariance(),
+                n_candidates=us_rand_num_candidates,
+            )
         else:
             raise NotImplementedError(f"The given policy ({policy}) is unknown.")
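
For reference, the "us_rand" branch added to from_problem above is equivalent to constructing the policy by hand; a short sketch using only names introduced by this commit (100 mirrors the default number of candidates):

    from probnum.quad.solvers.acquisition_functions import WeightedPredictiveVariance
    from probnum.quad.solvers.policies import RandomMaxAcquisitionPolicy

    # Same object that from_problem builds for policy="us_rand".
    policy = RandomMaxAcquisitionPolicy(
        batch_size=1,
        acquisition_func=WeightedPredictiveVariance(),
        n_candidates=100,
    )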

src/probnum/quad/solvers/_bq_state.py

Lines changed: 8 additions & 0 deletions
@@ -37,6 +37,8 @@ class BQState:
         Function evaluations at nodes.
     gram
         The kernel Gram matrix.
+    gram_cho_factor
+        The output of BQBeliefUpdate.compute_gram_cho_factor.
     kernel_means
         All kernel mean evaluations at ``nodes``.

@@ -55,6 +57,7 @@ def __init__(
         nodes: Optional[np.ndarray] = None,
         fun_evals: Optional[np.ndarray] = None,
         gram: np.ndarray = np.array([[]]),
+        gram_cho_factor: Tuple[np.ndarray, bool] = (np.array([[]]), False),
         kernel_means: np.ndarray = np.array([]),
     ):
         self.measure = measure
@@ -73,6 +76,7 @@ def __init__(
         self.fun_evals = fun_evals

         self.gram = gram
+        self.gram_cho_factor = gram_cho_factor
         self.kernel_means = kernel_means

     @classmethod
@@ -85,6 +89,7 @@ def from_new_data(
         integral_belief: Normal,
         prev_state: "BQState",
         gram: np.ndarray,
+        gram_cho_factor: Tuple[np.ndarray, bool],
         kernel_means: np.ndarray,
     ) -> "BQState":
         r"""Initialize state from updated data.
@@ -105,6 +110,8 @@
             Previous state of the BQ loop.
         gram
             The Gram matrix of the given nodes.
+        gram_cho_factor
+            The output of BQBeliefUpdate.compute_gram_cho_factor for ``gram``.
         kernel_means
             The kernel means at the given nodes.

@@ -123,6 +130,7 @@
             nodes=nodes,
             fun_evals=fun_evals,
             gram=gram,
+            gram_cho_factor=gram_cho_factor,
             kernel_means=kernel_means,
         )

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+"""Acquisition functions for Bayesian quadrature."""
+
+from ._acquisition_function import AcquisitionFunction
+from ._predictive_variance import WeightedPredictiveVariance
+
+# Public classes and functions. Order is reflected in documentation.
+__all__ = [
+    "AcquisitionFunction",
+    "WeightedPredictiveVariance",
+]
+
+# Set correct module paths. Corrects links and module paths in documentation.
+WeightedPredictiveVariance.__module__ = "probnum.quad.solvers.acquisition_functions"
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+"""Abstract base class for BQ acquisition functions."""
+
+from __future__ import annotations
+
+import abc
+from typing import Optional, Tuple
+
+import numpy as np
+
+from probnum.quad.solvers._bq_state import BQState
+
+
+class AcquisitionFunction(abc.ABC):
+    """An abstract class for an acquisition function for Bayesian quadrature."""
+
+    @property
+    @abc.abstractmethod
+    def has_gradients(self) -> bool:
+        """Whether the acquisition function exposes gradients."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def __call__(
+        self, x: np.ndarray, bq_state: BQState
+    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        """Evaluates the acquisition function and optionally its gradients.
+
+        Parameters
+        ----------
+        x
+            *shape=(batch_size, input_dim)* -- The nodes where the acquisition function
+            is being evaluated.
+        bq_state
+            State of the BQ belief.
+
+        Returns
+        -------
+        acquisition_values :
+            *shape=(batch_size, )* -- The acquisition values at nodes ``x``.
+        acquisition_gradients :
+            *shape=(batch_size, input_dim)* -- The corresponding gradients (optional).
+        """
+        raise NotImplementedError
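
To illustrate the contract defined by this base class, here is a minimal hypothetical subclass (ConstantAcquisition is not part of the commit): every concrete acquisition returns one value per candidate node plus optional gradients.

    import numpy as np

    from probnum.quad.solvers._bq_state import BQState
    from probnum.quad.solvers.acquisition_functions import AcquisitionFunction


    class ConstantAcquisition(AcquisitionFunction):
        """Toy acquisition that scores every candidate node equally."""

        @property
        def has_gradients(self) -> bool:
            return False  # no gradient information available

        def __call__(self, x: np.ndarray, bq_state: BQState):
            # One acquisition value per node in x, and no gradients.
            return np.ones(x.shape[0]), None
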
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+"""Uncertainty sampling for Bayesian Monte Carlo."""
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+from probnum.quad.solvers._bq_state import BQState
+from probnum.quad.solvers.belief_updates import BQStandardBeliefUpdate
+
+from ._acquisition_function import AcquisitionFunction
+
+# pylint: disable=too-few-public-methods, fixme
+
+
+class WeightedPredictiveVariance(AcquisitionFunction):
+    r"""The predictive variance acquisition function that yields uncertainty sampling.
+
+    The acquisition function is
+
+    .. math::
+        a(x) = \operatorname{Var}(f(x)) p(x)^2
+
+    where :math:`\operatorname{Var}(f(x))` is the predictive variance of the model and
+    :math:`p(x)` is the density of the integration measure :math:`\mu`.
+
+    """
+
+    @property
+    def has_gradients(self) -> bool:
+        # Todo (#581): this needs to return True, once gradients are available
+        return False
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        bq_state: BQState,
+    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        predictive_variance = bq_state.kernel(x, x)
+        if bq_state.fun_evals.shape != (0,):
+            kXx = bq_state.kernel.matrix(bq_state.nodes, x)
+            regression_weights = BQStandardBeliefUpdate.gram_cho_solve(
+                bq_state.gram_cho_factor, kXx
+            )
+            predictive_variance -= np.sum(regression_weights * kXx, axis=0)
+        values = bq_state.scale_sq * predictive_variance * bq_state.measure(x) ** 2
+        return values, None
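
The conditional branch above evaluates the GP posterior variance k(x, x) - k(X, x)^T K^{-1} k(X, x), reusing the Cholesky factor of the Gram matrix K cached in the state, and then weights it by the kernel scale and the squared density of the integration measure. A standalone sketch of the same variance computation with a toy RBF kernel (illustrative only; the commit's code path goes through BQState and the belief update):

    import numpy as np
    from scipy.linalg import cho_factor, cho_solve

    def rbf(a, b):
        # Toy RBF kernel on 1D inputs.
        return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2)

    X = np.array([0.0, 0.5, 1.0])  # nodes already evaluated
    x = np.array([0.25, 0.75])     # candidate nodes
    gram_cho = cho_factor(rbf(X, X) + 1e-8 * np.eye(X.size))
    kXx = rbf(X, x)
    post_var = np.diag(rbf(x, x)) - np.sum(cho_solve(gram_cho, kXx) * kXx, axis=0)
    acq = post_var * 1.0**2  # times squared density of the Lebesgue measure on [0, 1]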

src/probnum/quad/solvers/belief_updates/_belief_update.py

Lines changed: 23 additions & 6 deletions
@@ -58,7 +58,7 @@ def __call__(
         """
         raise NotImplementedError

-    def _compute_gram_cho_factor(self, gram: np.ndarray) -> np.ndarray:
+    def compute_gram_cho_factor(self, gram: np.ndarray) -> Tuple[np.ndarray, bool]:
         """Compute the Cholesky decomposition of a positive-definite Gram matrix for use
         in scipy.linalg.cho_solve

@@ -68,19 +68,36 @@ def _compute_gram_cho_factor(self, gram: np.ndarray) -> np.ndarray:

         Parameters
         ----------
-        gram :
+        gram
             symmetric pos. def. kernel Gram matrix :math:`K`, shape (nevals, nevals)

         Returns
         -------
         gram_cho_factor :
             The upper triangular Cholesky decomposition of the Gram matrix. Other
-            parts of the matrix contain random data.
+            parts of the matrix contain random data. A boolean that indicates whether
+            the matrix is lower triangular (always False but needed for scipy).
         """
         return cho_factor(gram + self.jitter * np.eye(gram.shape[0]))

-    # pylint: disable=no-self-use
-    def _gram_cho_solve(self, gram_cho_factor: np.ndarray, z: np.ndarray) -> np.ndarray:
+    @staticmethod
+    def gram_cho_solve(
+        gram_cho_factor: Tuple[np.ndarray, bool], z: np.ndarray
+    ) -> np.ndarray:
         """Wrapper for scipy.linalg.cho_solve. Meant to be used for linear systems of
-        the gram matrix. Requires the solution of scipy.linalg.cho_factor as input."""
+        the gram matrix. Requires the solution of scipy.linalg.cho_factor as input.
+
+        Parameters
+        ----------
+        gram_cho_factor
+            The return object of compute_gram_cho_factor.
+        z
+            An array of appropriate shape.
+
+        Returns
+        -------
+        solution :
+            The solution ``x`` to the linear system ``gram x = z``.
+
+        """
         return cho_solve(gram_cho_factor, z)
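
A standalone illustration (not from the commit) of the scipy pattern these helpers wrap: cho_factor returns a (factor, lower) tuple, which cho_solve then consumes to solve linear systems with the Gram matrix.

    import numpy as np
    from scipy.linalg import cho_factor, cho_solve

    K = np.array([[2.0, 0.5], [0.5, 1.0]])  # symmetric positive-definite "Gram" matrix
    gram_cho = cho_factor(K)                # tuple (upper-triangular factor, lower=False)
    z = np.array([1.0, 2.0])
    x = cho_solve(gram_cho, z)              # solves K x = z
    assert np.allclose(K @ x, z)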
