Skip to content

Commit 8dde541

Browse files
authored
Make quad policies stateless (#744)
1 parent 92e59ad commit 8dde541

File tree

9 files changed

+242
-68
lines changed

9 files changed

+242
-68
lines changed

src/probnum/quad/_bayesquad.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def bayesquad(
3333
var_tol: Optional[FloatLike] = None,
3434
rel_tol: Optional[FloatLike] = None,
3535
batch_size: IntLike = 1,
36-
rng: Optional[np.random.Generator] = np.random.default_rng(),
36+
rng: Optional[np.random.Generator] = None,
3737
jitter: FloatLike = 1.0e-8,
3838
) -> Tuple[Normal, BQIterInfo]:
3939
r"""Infer the solution of the uni- or multivariate integral
@@ -100,7 +100,7 @@ def bayesquad(
100100
Number of new observations at each update. Defaults to 1.
101101
rng
102102
Random number generator. Used by Bayesian Monte Carlo other random sampling
103-
policies. Optional. Default is `np.random.default_rng()`.
103+
policies.
104104
jitter
105105
Non-negative jitter to numerically stabilise kernel matrix inversion.
106106
Defaults to 1e-8.
@@ -145,9 +145,9 @@ def bayesquad(
145145
146146
>>> input_dim = 1
147147
>>> domain = (0, 1)
148-
>>> def f(x):
148+
>>> def fun(x):
149149
... return x.reshape(-1, )
150-
>>> F, info = bayesquad(fun=f, input_dim=input_dim, domain=domain)
150+
>>> F, info = bayesquad(fun, input_dim, domain=domain, rng=np.random.default_rng(0))
151151
>>> print(F.mean)
152152
0.5
153153
"""
@@ -167,12 +167,13 @@ def bayesquad(
167167
var_tol=var_tol,
168168
rel_tol=rel_tol,
169169
batch_size=batch_size,
170-
rng=rng,
171170
jitter=jitter,
172171
)
173172

174173
# Integrate
175-
integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None)
174+
integral_belief, _, info = bq_method.integrate(
175+
fun=fun, nodes=None, fun_evals=None, rng=rng
176+
)
176177

177178
return integral_belief, info
178179

@@ -261,7 +262,7 @@ def bayesquad_from_data(
261262

262263
# Integrate
263264
integral_belief, _, info = bq_method.integrate(
264-
fun=None, nodes=nodes, fun_evals=fun_evals
265+
fun=None, nodes=nodes, fun_evals=fun_evals, rng=None
265266
)
266267

267268
return integral_belief, info

src/probnum/quad/solvers/_bayesian_quadrature.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def from_problem(
8383
var_tol: Optional[FloatLike] = None,
8484
rel_tol: Optional[FloatLike] = None,
8585
batch_size: IntLike = 1,
86-
rng: np.random.Generator = None,
8786
jitter: FloatLike = 1.0e-8,
8887
) -> "BayesianQuadrature":
8988

@@ -112,8 +111,6 @@ def from_problem(
112111
Relative tolerance as stopping criterion.
113112
batch_size
114113
Batch size used in node acquisition. Defaults to 1.
115-
rng
116-
The random number generator.
117114
jitter
118115
Non-negative jitter to numerically stabilise kernel matrix inversion.
119116
Defaults to 1e-8.
@@ -127,9 +124,6 @@ def from_problem(
127124
------
128125
ValueError
129126
If neither a ``domain`` nor a ``measure`` are given.
130-
ValueError
131-
If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random
132-
number generator (``rng``) is given.
133127
NotImplementedError
134128
If an unknown ``policy`` is given.
135129
"""
@@ -153,15 +147,9 @@ def from_problem(
153147
# require an acquisition loop. The error handling is done in ``integrate``.
154148
pass
155149
elif policy == "bmc":
156-
if rng is None:
157-
errormsg = (
158-
"Policy 'bmc' relies on random sampling, "
159-
"thus requires a random number generator ('rng')."
160-
)
161-
raise ValueError(errormsg)
162-
policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng)
150+
policy = RandomPolicy(batch_size, measure.sample)
163151
elif policy == "vdc":
164-
policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size)
152+
policy = VanDerCorputPolicy(batch_size, measure)
165153
else:
166154
raise NotImplementedError(f"The given policy ({policy}) is unknown.")
167155

@@ -215,6 +203,7 @@ def bq_iterator(
215203
bq_state: BQState,
216204
info: Optional[BQIterInfo],
217205
fun: Optional[Callable],
206+
rng: Optional[np.random.Generator],
218207
) -> Tuple[Normal, BQState, BQIterInfo]:
219208
"""Generator that implements the iteration of the BQ method.
220209
@@ -231,6 +220,8 @@ def bq_iterator(
231220
fun
232221
Function to be integrated. It needs to accept a shape=(n_eval, input_dim)
233222
``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``.
223+
rng
224+
The random number generator used for random methods.
234225
235226
Yields
236227
------
@@ -258,7 +249,7 @@ def bq_iterator(
258249
break
259250

260251
# Select new nodes via policy
261-
new_nodes = self.policy(bq_state=bq_state)
252+
new_nodes = self.policy(bq_state, rng)
262253

263254
# Evaluate the integrand at new nodes
264255
new_fun_evals = fun(new_nodes)
@@ -278,6 +269,7 @@ def integrate(
278269
fun: Optional[Callable],
279270
nodes: Optional[np.ndarray],
280271
fun_evals: Optional[np.ndarray],
272+
rng: Optional[np.random.Generator] = None,
281273
) -> Tuple[Normal, BQState, BQIterInfo]:
282274
"""Integrates the function ``fun``.
283275
@@ -297,6 +289,8 @@ def integrate(
297289
fun_evals
298290
*shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available
299291
from the start.
292+
rng
293+
The random number generator used for random methods.
300294
301295
Returns
302296
-------
@@ -308,14 +302,17 @@ def integrate(
308302
Raises
309303
------
310304
ValueError
311-
If neither the integrand function (``fun``) nor integrand evaluations
312-
(``fun_evals``) are given.
305+
If neither the integrand function ``fun`` nor integrand evaluations
306+
``fun_evals`` are given.
313307
ValueError
314-
If ``nodes`` are not given and no policy is present.
308+
If neither ``nodes`` nor ``policy`` is given.
315309
ValueError
316310
If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their
317311
shapes do not match.
312+
ValueError
313+
If ``rng`` is not given but ``policy`` requires it.
318314
"""
315+
319316
# no policy given: Integrate on fixed dataset.
320317
if self.policy is None:
321318
# nodes must be provided if no policy is given.
@@ -325,13 +322,19 @@ def integrate(
325322
# Use fun_evals and disregard fun if both are given
326323
if fun is not None and fun_evals is not None:
327324
warnings.warn(
328-
"No policy available: 'fun_eval' are used instead of 'fun'."
325+
"No policy available: 'fun_evals' are used instead of 'fun'."
329326
)
330327
fun = None
331328

332329
# override stopping condition as no policy is given.
333330
self.stopping_criterion = ImmediateStop()
334331

332+
elif self.policy.requires_rng and rng is None:
333+
raise ValueError(
334+
f"The policy '{self.policy.__class__.__name__}' requires a random "
335+
f"number generator (rng) to be given."
336+
)
337+
335338
# Check if integrand function is provided
336339
if fun is None and fun_evals is None:
337340
raise ValueError(
@@ -375,7 +378,7 @@ def integrate(
375378
)
376379

377380
info = None
378-
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun):
381+
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng):
379382
pass
380383

381384
return bq_state.integral_belief, bq_state, info

src/probnum/quad/solvers/policies/_policy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Abstract base class for BQ policies."""
22

3+
from __future__ import annotations
4+
35
import abc
6+
from typing import Optional
47

58
import numpy as np
69

710
from probnum.quad.solvers._bq_state import BQState
11+
from probnum.typing import IntLike
812

913
# pylint: disable=too-few-public-methods, fixme
1014

@@ -18,17 +22,28 @@ class Policy(abc.ABC):
1822
Size of batch of nodes when calling the policy once.
1923
"""
2024

21-
def __init__(self, batch_size: int) -> None:
22-
self.batch_size = batch_size
25+
def __init__(self, batch_size: IntLike) -> None:
26+
self.batch_size = int(batch_size)
2327

28+
@property
2429
@abc.abstractmethod
25-
def __call__(self, bq_state: BQState) -> np.ndarray:
30+
def requires_rng(self) -> bool:
31+
"""Whether the policy requires a random number generator when called."""
32+
raise NotImplementedError
33+
34+
@abc.abstractmethod
35+
def __call__(
36+
self, bq_state: BQState, rng: Optional[np.random.Generator]
37+
) -> np.ndarray:
2638
"""Find nodes according to the policy.
2739
2840
Parameters
2941
----------
3042
bq_state
3143
State of the BQ belief.
44+
rng
45+
A random number generator.
46+
3247
Returns
3348
-------
3449
nodes :
Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Random policy for Bayesian Monte Carlo."""
22

3-
from typing import Callable
3+
from __future__ import annotations
4+
5+
from typing import Callable, Optional
46

57
import numpy as np
68

79
from probnum.quad.solvers._bq_state import BQState
10+
from probnum.typing import IntLike
811

912
from ._policy import Policy
1013

@@ -16,25 +19,27 @@ class RandomPolicy(Policy):
1619
1720
Parameters
1821
----------
22+
batch_size
23+
Size of batch of nodes when calling the policy once.
1924
sample_func
2025
The sample function. Needs to have the following interface:
2126
`sample_func(batch_size: int, rng: np.random.Generator)` and return an array of
22-
shape (batch_size, n_dim).
23-
batch_size
24-
Size of batch of nodes when calling the policy once.
25-
rng
26-
A random number generator.
27+
shape (batch_size, input_dim).
2728
"""
2829

2930
def __init__(
3031
self,
32+
batch_size: IntLike,
3133
sample_func: Callable,
32-
batch_size: int,
33-
rng: np.random.Generator = np.random.default_rng(),
3434
) -> None:
3535
super().__init__(batch_size=batch_size)
3636
self.sample_func = sample_func
37-
self.rng = rng
3837

39-
def __call__(self, bq_state: BQState) -> np.ndarray:
40-
return self.sample_func(self.batch_size, rng=self.rng)
38+
@property
39+
def requires_rng(self) -> bool:
40+
return True
41+
42+
def __call__(
43+
self, bq_state: BQState, rng: Optional[np.random.Generator]
44+
) -> np.ndarray:
45+
return self.sample_func(self.batch_size, rng=rng)

src/probnum/quad/solvers/policies/_van_der_corput_policy.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Van der Corput points for integration on 1D intervals."""
22

3+
from __future__ import annotations
4+
35
from typing import Optional
46

57
import numpy as np
68

79
from probnum.quad.integration_measures import IntegrationMeasure
810
from probnum.quad.solvers._bq_state import BQState
11+
from probnum.typing import IntLike
912

1013
from ._policy import Policy
1114

@@ -22,17 +25,17 @@ class VanDerCorputPolicy(Policy):
2225
2326
Parameters
2427
----------
25-
measure
26-
The integration measure with finite domain.
2728
batch_size
2829
Size of batch of nodes when calling the policy once.
30+
measure
31+
The integration measure with finite domain.
2932
3033
References
3134
--------
3235
.. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence
3336
"""
3437

35-
def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
38+
def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None:
3639
super().__init__(batch_size=batch_size)
3740

3841
if int(measure.input_dim) > 1:
@@ -46,7 +49,13 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
4649
self.domain_a = domain_a
4750
self.domain_b = domain_b
4851

49-
def __call__(self, bq_state: BQState) -> np.ndarray:
52+
@property
53+
def requires_rng(self) -> bool:
54+
return False
55+
56+
def __call__(
57+
self, bq_state: BQState, rng: Optional[np.random.Generator]
58+
) -> np.ndarray:
5059
n_nodes = bq_state.nodes.shape[0]
5160
vdc_seq = VanDerCorputPolicy.van_der_corput_sequence(
5261
n_nodes + 1, n_nodes + 1 + self.batch_size

src/probnum/quad/solvers/stopping_criteria/_max_nevals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion):
1919
"""
2020

2121
def __init__(self, max_nevals: IntLike):
22-
self.max_nevals = max_nevals
22+
self.max_nevals = int(max_nevals)
2323

2424
def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool:
2525
return info.nevals >= self.max_nevals

0 commit comments

Comments
 (0)