From 73ea344820856b98acd837fde32fde435d1d150a Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 21 Nov 2022 14:32:08 +0100 Subject: [PATCH 01/19] adding info to kernel embeddings docs --- .../kernel_embeddings/_kernel_embedding.py | 104 ++++++++++-------- 1 file changed, 61 insertions(+), 43 deletions(-) diff --git a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py index be0f742ac..dad4bedc5 100644 --- a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py +++ b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py @@ -25,6 +25,24 @@ class KernelEmbedding: """Integrals over kernels against integration measures. + The available kernel embeddings are: + + ============= =============== + ExpQuad LebesgueMeasure + ============= =============== + + ============= =============== + ExpQuad GaussianMeasure + ============= =============== + + ============= =============== + Matern (1d) LebesgueMeasure + ============= =============== + + ============= =============== + ProductMatern LebesgueMeasure + ============= =============== + Parameters ---------- kernel: @@ -51,7 +69,7 @@ def __init__(self, kernel: Kernel, measure: IntegrationMeasure) -> None: (self.input_dim,) = self.kernel.input_shape # retrieve the functions for the provided combination of kernel and measure - self._kmean, self._kvar = _get_kernel_embedding( + self._kmean, self._kvar = self._get_kernel_embedding( kernel=self.kernel, measure=self.measure ) @@ -82,48 +100,48 @@ def kernel_variance(self) -> float: """ return self._kvar(kernel=self.kernel, measure=self.measure) + @staticmethod + def _get_kernel_embedding( + kernel: Kernel, measure: IntegrationMeasure + ) -> Tuple[Callable, Callable]: + """Select the right kernel embedding given the kernel and integration measure. -def _get_kernel_embedding( - kernel: Kernel, measure: IntegrationMeasure -) -> Tuple[Callable, Callable]: - """Select the right kernel embedding given the kernel and integration measure. - - Parameters - ---------- - kernel : - Instance of a kernel. - measure : - Instance of an integration measure. - - Returns - ------- - kernel_mean : - The kernel mean function. - kernel_variance : - The kernel variance function. + Parameters + ---------- + kernel : + Instance of a kernel. + measure : + Instance of an integration measure. - Raises - ------ - NotImplementedError - If the given kernel is unknown. - NotImplementedError - If the kernel embedding of the kernel-measure pair is unknown. - """ + Returns + ------- + kernel_mean : + The kernel mean function. + kernel_variance : + The kernel variance function. + + Raises + ------ + NotImplementedError + If the given kernel is unknown. + NotImplementedError + If the kernel embedding of the kernel-measure pair is unknown. + """ - # Exponentiated quadratic kernel - if isinstance(kernel, ExpQuad): - if isinstance(measure, GaussianMeasure): - return _kernel_mean_expquad_gauss, _kernel_variance_expquad_gauss - if isinstance(measure, LebesgueMeasure): - return _kernel_mean_expquad_lebesgue, _kernel_variance_expquad_lebesgue - - # Matern - if isinstance(kernel, (Matern, ProductMatern)): - if isinstance(measure, LebesgueMeasure): - return _kernel_mean_matern_lebesgue, _kernel_variance_matern_lebesgue - - # other kernels - raise NotImplementedError( - "The combination of kernel ({0}) and measure ({1}) is not available as kernel " - "embedding.".format(type(kernel), type(measure)) - ) + # Exponentiated quadratic kernel + if isinstance(kernel, ExpQuad): + if isinstance(measure, GaussianMeasure): + return _kernel_mean_expquad_gauss, _kernel_variance_expquad_gauss + if isinstance(measure, LebesgueMeasure): + return _kernel_mean_expquad_lebesgue, _kernel_variance_expquad_lebesgue + + # Matern + if isinstance(kernel, (Matern, ProductMatern)): + if isinstance(measure, LebesgueMeasure): + return _kernel_mean_matern_lebesgue, _kernel_variance_matern_lebesgue + + # other kernels + raise NotImplementedError( + "The combination of kernel ({0}) and measure ({1}) is not available as kernel " + "embedding.".format(type(kernel), type(measure)) + ) From f19568793b8dcffc0d4dbc86d17fb51ee8443c57 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 21 Nov 2022 14:46:54 +0100 Subject: [PATCH 02/19] resolving one of the pylint errors --- src/probnum/quad/kernel_embeddings/_kernel_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py index dad4bedc5..d02916f5f 100644 --- a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py +++ b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py @@ -142,6 +142,6 @@ def _get_kernel_embedding( # other kernels raise NotImplementedError( - "The combination of kernel ({0}) and measure ({1}) is not available as kernel " - "embedding.".format(type(kernel), type(measure)) + "The combination of kernel ({0}) and measure ({1}) is not available as " + "kernel embedding.".format(type(kernel), type(measure)) ) From dbedd41120612eaa92f380f8b9b810b0929c1c10 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 21 Nov 2022 20:05:28 +0100 Subject: [PATCH 03/19] better table --- .../quad/kernel_embeddings/_kernel_embedding.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py index d02916f5f..6f7a90389 100644 --- a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py +++ b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py @@ -29,18 +29,9 @@ class KernelEmbedding: ============= =============== ExpQuad LebesgueMeasure - ============= =============== - - ============= =============== ExpQuad GaussianMeasure - ============= =============== - - ============= =============== - Matern (1d) LebesgueMeasure - ============= =============== - - ============= =============== - ProductMatern LebesgueMeasure + Matern (1d) LebesgueMeasure + ProductMatern LebesgueMeasure ============= =============== Parameters From d6a553c820766dedd68536467dd6b2aa64eeaf0d Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 21 Nov 2022 20:51:17 +0100 Subject: [PATCH 04/19] stateless policy --- src/probnum/quad/_bayesquad.py | 7 ++-- .../kernel_embeddings/_kernel_embedding.py | 2 +- .../quad/solvers/_bayesian_quadrature.py | 32 +++++++++---------- src/probnum/quad/solvers/policies/_policy.py | 8 ++++- .../quad/solvers/policies/_random_policy.py | 11 ++++--- .../policies/_van_der_corput_policy.py | 4 ++- tests/test_quad/test_bayesian_quadrature.py | 5 +-- tests/test_quad/test_bayesquad/test_bq.py | 12 +++++-- 8 files changed, 46 insertions(+), 35 deletions(-) diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index ffb185dea..df43c72f4 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -167,12 +167,13 @@ def bayesquad( var_tol=var_tol, rel_tol=rel_tol, batch_size=batch_size, - rng=rng, jitter=jitter, ) # Integrate - integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None) + integral_belief, _, info = bq_method.integrate( + fun=fun, nodes=None, fun_evals=None, rng=rng + ) return integral_belief, info @@ -261,7 +262,7 @@ def bayesquad_from_data( # Integrate integral_belief, _, info = bq_method.integrate( - fun=None, nodes=nodes, fun_evals=fun_evals + fun=None, nodes=nodes, fun_evals=fun_evals, rng=None ) return integral_belief, info diff --git a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py index 6f7a90389..57cf7ca45 100644 --- a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py +++ b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py @@ -32,7 +32,7 @@ class KernelEmbedding: ExpQuad GaussianMeasure Matern (1d) LebesgueMeasure ProductMatern LebesgueMeasure - ============= =============== + ============= =============== Parameters ---------- diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index 6b037ecba..a73a1f76d 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import warnings import numpy as np @@ -83,7 +83,6 @@ def from_problem( var_tol: Optional[FloatLike] = None, rel_tol: Optional[FloatLike] = None, batch_size: IntLike = 1, - rng: np.random.Generator = None, jitter: FloatLike = 1.0e-8, ) -> "BayesianQuadrature": @@ -112,8 +111,6 @@ def from_problem( Relative tolerance as stopping criterion. batch_size Batch size used in node acquisition. Defaults to 1. - rng - The random number generator. jitter Non-negative jitter to numerically stabilise kernel matrix inversion. Defaults to 1e-8. @@ -127,9 +124,6 @@ def from_problem( ------ ValueError If neither a ``domain`` nor a ``measure`` are given. - ValueError - If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random - number generator (``rng``) is given. NotImplementedError If an unknown ``policy`` is given. """ @@ -153,13 +147,7 @@ def from_problem( # require an acquisition loop. The error handling is done in ``integrate``. pass elif policy == "bmc": - if rng is None: - errormsg = ( - "Policy 'bmc' relies on random sampling, " - "thus requires a random number generator ('rng')." - ) - raise ValueError(errormsg) - policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng) + policy = RandomPolicy(measure.sample, batch_size=batch_size) elif policy == "vdc": policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size) else: @@ -215,7 +203,8 @@ def bq_iterator( bq_state: BQState, info: Optional[BQIterInfo], fun: Optional[Callable], - ) -> Tuple[Normal, BQState, BQIterInfo]: + rng: np.random.Generator, + ) -> Tuple[Normal, BQState, BQIterInfo, np.random.Generator]: """Generator that implements the iteration of the BQ method. This function exposes the state of the BQ method one step at a time while @@ -231,6 +220,8 @@ def bq_iterator( fun Function to be integrated. It needs to accept a shape=(n_eval, input_dim) ``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``. + rng + The random number generator used for random methods. Yields ------ @@ -258,7 +249,7 @@ def bq_iterator( break # Select new nodes via policy - new_nodes = self.policy(bq_state=bq_state) + new_nodes = self.policy(bq_state=bq_state, rng=rng) # Evaluate the integrand at new nodes new_fun_evals = fun(new_nodes) @@ -278,6 +269,7 @@ def integrate( fun: Optional[Callable], nodes: Optional[np.ndarray], fun_evals: Optional[np.ndarray], + rng: Union[IntLike, np.random.Generator] = np.random.default_rng(), ) -> Tuple[Normal, BQState, BQIterInfo]: """Integrates the function ``fun``. @@ -297,6 +289,8 @@ def integrate( fun_evals *shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available from the start. + rng + The random number generator used for random methods, or a seed. Returns ------- @@ -316,6 +310,10 @@ def integrate( If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their shapes do not match. """ + # Get the rng + if isinstance(rng, IntLike): + rng = np.random.default_rng(int(rng)) + # no policy given: Integrate on fixed dataset. if self.policy is None: # nodes must be provided if no policy is given. @@ -375,7 +373,7 @@ def integrate( ) info = None - for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun): + for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng): pass return bq_state.integral_belief, bq_state, info diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index f57d3438e..d5342b565 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -1,6 +1,7 @@ """Abstract base class for BQ policies.""" import abc +from typing import Optional import numpy as np @@ -22,13 +23,18 @@ def __init__(self, batch_size: int) -> None: self.batch_size = batch_size @abc.abstractmethod - def __call__(self, bq_state: BQState) -> np.ndarray: + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: """Find nodes according to the policy. Parameters ---------- bq_state State of the BQ belief. + rng + A random number generator. + Returns ------- nodes : diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index 6110ca465..96f7e8d0a 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -1,6 +1,6 @@ """Random policy for Bayesian Monte Carlo.""" -from typing import Callable +from typing import Callable, Optional import numpy as np @@ -22,17 +22,18 @@ class RandomPolicy(Policy): shape (batch_size, n_dim). batch_size Size of batch of nodes when calling the policy once. + """ def __init__( self, sample_func: Callable, batch_size: int, - rng: np.random.Generator = np.random.default_rng(), ) -> None: super().__init__(batch_size=batch_size) self.sample_func = sample_func - self.rng = rng - def __call__(self, bq_state: BQState) -> np.ndarray: - return self.sample_func(self.batch_size, rng=self.rng) + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: + return self.sample_func(self.batch_size, rng=rng) diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 1e7edfe85..8656881c0 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -46,7 +46,9 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None: self.domain_a = domain_a self.domain_b = domain_b - def __call__(self, bq_state: BQState) -> np.ndarray: + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: n_nodes = bq_state.nodes.shape[0] vdc_seq = VanDerCorputPolicy.van_der_corput_sequence( n_nodes + 1, n_nodes + 1 + self.batch_size diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 9af6fd4a9..ea8577903 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -31,7 +31,6 @@ def bq(input_dim): return BayesianQuadrature.from_problem( input_dim=input_dim, domain=(np.zeros(input_dim), np.ones(input_dim)), - rng=np.random.default_rng(), ) @@ -56,9 +55,7 @@ def test_bq_from_problem_wrong_inputs(input_dim): ) def test_bq_from_problem_policy_assignment(policy, policy_type): """Test if correct policy is assigned from string identifier.""" - bq = BayesianQuadrature.from_problem( - input_dim=1, domain=(0, 1), policy=policy, rng=np.random.default_rng() - ) + bq = BayesianQuadrature.from_problem(input_dim=1, domain=(0, 1), policy=policy) assert isinstance(bq.policy, policy_type) diff --git a/tests/test_quad/test_bayesquad/test_bq.py b/tests/test_quad/test_bayesquad/test_bq.py index 704c62976..37d08a200 100644 --- a/tests/test_quad/test_bayesquad/test_bq.py +++ b/tests/test_quad/test_bayesquad/test_bq.py @@ -18,10 +18,15 @@ def rng(): @pytest.mark.parametrize("input_dim", [1], ids=["dim1"]) -def test_type_1d(f1d, kernel, measure, input_dim): +def test_type_1d(f1d, kernel, measure, input_dim, rng): """Test that BQ outputs normal random variables for 1D integrands.""" integral, _ = bayesquad( - fun=f1d, input_dim=input_dim, kernel=kernel, measure=measure, max_evals=10 + fun=f1d, + input_dim=input_dim, + kernel=kernel, + measure=measure, + max_evals=10, + rng=rng, ) assert isinstance(integral, Normal) @@ -43,7 +48,7 @@ def test_type_1d(f1d, kernel, measure, input_dim): @pytest.mark.parametrize("scale_estimation", [None, "mle"]) @pytest.mark.parametrize("jitter", [1e-6, 1e-7]) def test_integral_values_1d( - f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter + f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter, rng ): """Test numerically that BQ computes 1D integrals correctly for a number of different parameters. @@ -70,6 +75,7 @@ def integrand(x): var_tol=var_tol, rel_tol=rel_tol, jitter=jitter, + rng=rng, ) domain = measure.domain num_integral, _ = scipyquad(integrand, domain[0], domain[1]) From 9b4f2fdd659140d01a0b8d5e0b92e47980b7bd90 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 21 Nov 2022 21:01:11 +0100 Subject: [PATCH 05/19] small change --- src/probnum/quad/solvers/_bayesian_quadrature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index a73a1f76d..bd2d82f62 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -204,7 +204,7 @@ def bq_iterator( info: Optional[BQIterInfo], fun: Optional[Callable], rng: np.random.Generator, - ) -> Tuple[Normal, BQState, BQIterInfo, np.random.Generator]: + ) -> Tuple[Normal, BQState, BQIterInfo]: """Generator that implements the iteration of the BQ method. This function exposes the state of the BQ method one step at a time while From 546189c98d390e853587f9a973633e8b7795eb92 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 22 Nov 2022 15:35:49 +0100 Subject: [PATCH 06/19] small fix --- src/probnum/quad/solvers/_bayesian_quadrature.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index bd2d82f62..b49ab4295 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, get_args import warnings import numpy as np @@ -311,7 +311,7 @@ def integrate( shapes do not match. """ # Get the rng - if isinstance(rng, IntLike): + if isinstance(rng, get_args(IntLike)): rng = np.random.default_rng(int(rng)) # no policy given: Integrate on fixed dataset. From bbc3148b592fc2833880d0e4bbd5845c2ac78f0c Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 22 Nov 2022 16:26:47 +0100 Subject: [PATCH 07/19] fix tests --- .../solvers/stopping_criteria/_max_nevals.py | 2 +- tests/test_quad/test_bayesian_quadrature.py | 42 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py index bf87dd252..9d50d19c9 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion): """ def __init__(self, max_nevals: IntLike): - self.max_nevals = max_nevals + self.max_nevals = int(max_nevals) def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool: return info.nevals >= self.max_nevals diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index ea8577903..61a2da4e7 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -7,7 +7,12 @@ from probnum.quad.integration_measures import LebesgueMeasure from probnum.quad.solvers import BayesianQuadrature from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy -from probnum.quad.solvers.stopping_criteria import ImmediateStop +from probnum.quad.solvers.stopping_criteria import ( + ImmediateStop, + IntegralVarianceTolerance, + MaxNevals, + RelativeMeanChange, +) from probnum.randprocs.kernels import ExpQuad @@ -78,6 +83,30 @@ def test_bq_from_problem_defaults(bq_no_policy, bq): assert isinstance(bq.kernel, ExpQuad) +@pytest.mark.parametrize( + "max_evals, var_tol, rel_tol, t", + [ + (None, None, None, LambdaStoppingCriterion), + (1000, None, None, MaxNevals), + (None, 1e-5, None, IntegralVarianceTolerance), + (None, None, 1e-5, RelativeMeanChange), + (None, 1e-5, 1e-5, LambdaStoppingCriterion), + (1000, None, 1e-5, LambdaStoppingCriterion), + (1000, 1e-5, None, LambdaStoppingCriterion), + (1000, 1e-5, 1e-5, LambdaStoppingCriterion), + ], +) +def test_bq_from_problem_stopping_condition_assignment(max_evals, var_tol, rel_tol, t): + bq = BayesianQuadrature.from_problem( + input_dim=2, + domain=(0, 1), + max_evals=max_evals, + var_tol=var_tol, + rel_tol=rel_tol, + ) + assert isinstance(bq.stopping_criterion, t) + + def test_integrate_no_policy_wrong_input(bq_no_policy, data): nodes, fun_evals, fun = data @@ -117,3 +146,14 @@ def test_integrate_wrong_input(bq, bq_no_policy, data): bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=wrong_nodes, fun_evals=fun_evals) + + +@pytest.mark.parametrize("rng", [np.random.default_rng(42), 42]) +def test_integrate_runs_with_integer_rng(bq, data, rng): + # rng is a generator + nodes, fun_evals, fun = data + bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng) + + # rng is a seed + nodes, fun_evals, fun = data + bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng) From 18ae4d485433ef00ad9145035f749c25469afe99 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 22 Nov 2022 18:01:13 +0100 Subject: [PATCH 08/19] small fix --- src/probnum/quad/solvers/policies/_van_der_corput_policy.py | 2 +- tests/test_quad/test_policy.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 8656881c0..1218aadbc 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -40,7 +40,7 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None: domain_a = measure.domain[0] domain_b = measure.domain[1] - if np.Inf in [abs(domain_a), abs(domain_b)]: + if np.Inf in np.hstack([abs(measure.domain[0]), abs(measure.domain[1])]): raise ValueError("Policy 'vdc' works only for bounded domains.") self.domain_a = domain_a diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index fa848da73..3937ba570 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -6,6 +6,8 @@ from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure from probnum.quad.solvers.policies import VanDerCorputPolicy +# Todo: test other policies, too + def test_van_der_corput_multi_d_error(): """Check that van der Corput policy fails in dimensions higher than one.""" From a07fa0791b61a2526f0b3eabb4f38a12f397a196 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 22 Nov 2022 19:19:11 +0100 Subject: [PATCH 09/19] adding shape tests for policies --- tests/test_quad/test_policy.py | 85 +++++++++++++++++++++++++++++++++- 1 file changed, 83 insertions(+), 2 deletions(-) diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index 3937ba570..5528a921d 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -4,9 +4,90 @@ import pytest from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure -from probnum.quad.solvers.policies import VanDerCorputPolicy +from probnum.quad.solvers import BQState +from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy +from probnum.randprocs.kernels import ExpQuad -# Todo: test other policies, too + +@pytest.fixture +def batch_size(): + return 3 + + +@pytest.fixture +def bq_state_no_data(input_dim): + return BQState( + measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(input_dim,)), + ) + + +@pytest.fixture +def bq_state(input_dim): + nevals = 5 + return BQState( + measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(input_dim,)), + nodes=np.zeros([nevals, input_dim]), + fun_evals=np.ones(nevals), + ) + + +@pytest.fixture +def sample_func(batch_size, input_dim, rng): + def f(batch_size, rng): + return np.ones([batch_size, input_dim]) + + return f + + +@pytest.fixture( + params=[ + pytest.param(sc, id=sc[0].__name__) + for sc in [ + (RandomPolicy, dict(batch_size="batch_size", sample_func="sample_func")), + ] + ], + name="policy", +) +def fixture_policy(request) -> Policy: + """Policies that only allow univariate inputs need to be handled separately.""" + params = {} + for key in request.param[1]: + params[key] = request.getfixturevalue(request.param[1][key]) + return request.param[0](**params) + + +def test_policy_shapes(policy, batch_size, rng, input_dim, bq_state, bq_state_no_data): + + # bq state contains data + assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, input_dim) + + # bq state contains no data yet + assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, input_dim) + + +# Tests specific to VanDerCorputPolicy start here + + +def test_van_der_corput_shapes(batch_size, rng): + """This is the same test as test_policies_shapes but for 1d only.""" + measure = LebesgueMeasure(domain=(0, 1)) + policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size) + + # bq state contains no data yet + bq_state_no_data = BQState(measure=measure, kernel=ExpQuad(input_shape=(1,))) + assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, 1) + + # bq state contains data + nevals = 5 + bq_state = BQState( + measure=measure, + kernel=ExpQuad(input_shape=(1,)), + nodes=np.zeros([nevals, 1]), + fun_evals=np.ones(nevals), + ) + assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, 1) def test_van_der_corput_multi_d_error(): From 0898beb1a63b03e7ee8688477a939f359ff9f10e Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Wed, 23 Nov 2022 16:12:47 +0100 Subject: [PATCH 10/19] parametrize policy tests --- tests/test_quad/test_bayesian_quadrature.py | 6 +- tests/test_quad/test_policy.py | 135 +++++++++++--------- 2 files changed, 74 insertions(+), 67 deletions(-) diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 61a2da4e7..58703d230 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -150,10 +150,6 @@ def test_integrate_wrong_input(bq, bq_no_policy, data): @pytest.mark.parametrize("rng", [np.random.default_rng(42), 42]) def test_integrate_runs_with_integer_rng(bq, data, rng): - # rng is a generator - nodes, fun_evals, fun = data - bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng) - - # rng is a seed + # make sure integrate runs with both a rn generator and a seed. nodes, fun_evals, fun = data bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng) diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index 5528a921d..449f50aae 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -1,5 +1,10 @@ """Basic tests for BQ policies.""" + +# New policies need to be added to the fixtures 'policy_name' and 'policy_params' +# and 'policy'. + + import numpy as np import pytest @@ -14,82 +19,88 @@ def batch_size(): return 3 -@pytest.fixture -def bq_state_no_data(input_dim): - return BQState( - measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), - kernel=ExpQuad(input_shape=(input_dim,)), - ) - - -@pytest.fixture -def bq_state(input_dim): - nevals = 5 - return BQState( - measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), - kernel=ExpQuad(input_shape=(input_dim,)), - nodes=np.zeros([nevals, input_dim]), - fun_evals=np.ones(nevals), - ) - - -@pytest.fixture -def sample_func(batch_size, input_dim, rng): - def f(batch_size, rng): - return np.ones([batch_size, input_dim]) - - return f - - @pytest.fixture( params=[ - pytest.param(sc, id=sc[0].__name__) - for sc in [ - (RandomPolicy, dict(batch_size="batch_size", sample_func="sample_func")), - ] - ], - name="policy", + pytest.param(name, id=name) for name in ["RandomPolicy", "VanDerCorputPolicy"] + ] ) -def fixture_policy(request) -> Policy: - """Policies that only allow univariate inputs need to be handled separately.""" - params = {} - for key in request.param[1]: - params[key] = request.getfixturevalue(request.param[1][key]) - return request.param[0](**params) +def policy_name(request): + return request.param -def test_policy_shapes(policy, batch_size, rng, input_dim, bq_state, bq_state_no_data): +@pytest.fixture +def policy_params(policy_name, input_dim, batch_size, rng): + def _get_bq_states(ndim): + nevals = 5 + bq_state_no_data = BQState( + measure=LebesgueMeasure(input_dim=ndim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(ndim,)), + ) + bq_state = BQState( + measure=LebesgueMeasure(input_dim=ndim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(ndim,)), + nodes=np.zeros([nevals, ndim]), + fun_evals=np.ones(nevals), + ) + return bq_state, bq_state_no_data + + params = dict(name=policy_name, ndim=input_dim) + params["bq_state"], params["bq_state_no_data"] = _get_bq_states(input_dim) + + if policy_name == "RandomPolicy": + input_params = dict( + batch_size=batch_size, + sample_func=lambda batch_size, rng: np.ones([batch_size, input_dim]), + ) + elif policy_name == "VanDerCorputPolicy": + # Since VanDerCorputPolicy can only produce univariate nodes, this overrides + # input_dim = 1 for all tests. This is a bit cheap, but pytest parametrization + # is convoluted enough. + input_params = dict( + batch_size=batch_size, + measure=LebesgueMeasure(input_dim=1, domain=(0, 1)), + ) + params["bq_state"], params["bq_state_no_data"] = _get_bq_states(1) + params["ndim"] = 1 + else: + raise NotImplementedError + + params["input_params"] = input_params + + return params + + +@pytest.fixture() +def policy(policy_params): + name = policy_params.pop("name") + input_params = policy_params.pop("input_params") + + if name == "RandomPolicy": + return RandomPolicy(**input_params), policy_params + elif name == "VanDerCorputPolicy": + return VanDerCorputPolicy(**input_params), policy_params + else: + raise NotImplementedError + + +# Tests shared by all policies start here. + + +def test_policy_shapes(policy, batch_size, rng): + policy, params = policy + bq_state, bq_state_no_data = params["bq_state"], params["bq_state_no_data"] + ndim = params["ndim"] # bq state contains data - assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, input_dim) + assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, ndim) # bq state contains no data yet - assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, input_dim) + assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, ndim) # Tests specific to VanDerCorputPolicy start here -def test_van_der_corput_shapes(batch_size, rng): - """This is the same test as test_policies_shapes but for 1d only.""" - measure = LebesgueMeasure(domain=(0, 1)) - policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size) - - # bq state contains no data yet - bq_state_no_data = BQState(measure=measure, kernel=ExpQuad(input_shape=(1,))) - assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, 1) - - # bq state contains data - nevals = 5 - bq_state = BQState( - measure=measure, - kernel=ExpQuad(input_shape=(1,)), - nodes=np.zeros([nevals, 1]), - fun_evals=np.ones(nevals), - ) - assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, 1) - - def test_van_der_corput_multi_d_error(): """Check that van der Corput policy fails in dimensions higher than one.""" wrong_dimension = 2 From 33bb4a6d47759c28e4dd300a5255901c02d0f2ff Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Wed, 23 Nov 2022 16:16:47 +0100 Subject: [PATCH 11/19] resolving one pylint error --- tests/test_quad/test_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index 449f50aae..291b77392 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -10,7 +10,7 @@ from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure from probnum.quad.solvers import BQState -from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy +from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy from probnum.randprocs.kernels import ExpQuad From eb84d0e8f63f7ff28f5df1ccd067216f147da9ef Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Wed, 23 Nov 2022 16:27:26 +0100 Subject: [PATCH 12/19] fixing interface --- src/probnum/quad/solvers/policies/_policy.py | 5 +++-- src/probnum/quad/solvers/policies/_random_policy.py | 9 +++++---- .../quad/solvers/policies/_van_der_corput_policy.py | 7 ++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index d5342b565..0e0cfbf41 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -6,6 +6,7 @@ import numpy as np from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike # pylint: disable=too-few-public-methods, fixme @@ -19,8 +20,8 @@ class Policy(abc.ABC): Size of batch of nodes when calling the policy once. """ - def __init__(self, batch_size: int) -> None: - self.batch_size = batch_size + def __init__(self, batch_size: IntLike) -> None: + self.batch_size = int(batch_size) @abc.abstractmethod def __call__( diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index 7423bba53..3884e2c9f 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -5,6 +5,7 @@ import numpy as np from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike from ._policy import Policy @@ -16,18 +17,18 @@ class RandomPolicy(Policy): Parameters ---------- + batch_size + Size of batch of nodes when calling the policy once. sample_func The sample function. Needs to have the following interface: `sample_func(batch_size: int, rng: np.random.Generator)` and return an array of - shape (batch_size, n_dim). - batch_size - Size of batch of nodes when calling the policy once. + shape (batch_size, input_dim). """ def __init__( self, + batch_size: IntLike, sample_func: Callable, - batch_size: int, ) -> None: super().__init__(batch_size=batch_size) self.sample_func = sample_func diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 1218aadbc..2010b0a92 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -6,6 +6,7 @@ from probnum.quad.integration_measures import IntegrationMeasure from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike from ._policy import Policy @@ -22,17 +23,17 @@ class VanDerCorputPolicy(Policy): Parameters ---------- - measure - The integration measure with finite domain. batch_size Size of batch of nodes when calling the policy once. + measure + The integration measure with finite domain. References -------- .. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence """ - def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None: + def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None: super().__init__(batch_size=batch_size) if int(measure.input_dim) > 1: From 8ab844c716061cf924fa643d3075258e1c715b07 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Wed, 23 Nov 2022 16:56:21 +0100 Subject: [PATCH 13/19] api fix --- src/probnum/quad/solvers/_bayesian_quadrature.py | 6 +++--- tests/test_quad/test_policy.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index b49ab4295..13247d127 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -147,9 +147,9 @@ def from_problem( # require an acquisition loop. The error handling is done in ``integrate``. pass elif policy == "bmc": - policy = RandomPolicy(measure.sample, batch_size=batch_size) + policy = RandomPolicy(batch_size, measure.sample) elif policy == "vdc": - policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size) + policy = VanDerCorputPolicy(batch_size, measure) else: raise NotImplementedError(f"The given policy ({policy}) is unknown.") @@ -249,7 +249,7 @@ def bq_iterator( break # Select new nodes via policy - new_nodes = self.policy(bq_state=bq_state, rng=rng) + new_nodes = self.policy(bq_state, rng) # Evaluate the integrand at new nodes new_fun_evals = fun(new_nodes) diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index 291b77392..7af3df16a 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -92,10 +92,10 @@ def test_policy_shapes(policy, batch_size, rng): ndim = params["ndim"] # bq state contains data - assert policy(bq_state=bq_state, rng=rng).shape == (batch_size, ndim) + assert policy(bq_state, rng).shape == (batch_size, ndim) # bq state contains no data yet - assert policy(bq_state=bq_state_no_data, rng=rng).shape == (batch_size, ndim) + assert policy(bq_state_no_data, rng).shape == (batch_size, ndim) # Tests specific to VanDerCorputPolicy start here @@ -106,7 +106,7 @@ def test_van_der_corput_multi_d_error(): wrong_dimension = 2 measure = GaussianMeasure(input_dim=wrong_dimension, mean=0.0, cov=1.0) with pytest.raises(ValueError): - VanDerCorputPolicy(measure, batch_size=1) + VanDerCorputPolicy(1, measure) @pytest.mark.parametrize("domain", [(-np.Inf, 0), (1, np.Inf), (-np.Inf, np.Inf)]) @@ -114,7 +114,7 @@ def test_van_der_corput_infinite_error(domain): """Check that van der Corput policy fails on infinite domains.""" measure = LebesgueMeasure(input_dim=1, domain=domain) with pytest.raises(ValueError): - VanDerCorputPolicy(measure, batch_size=1) + VanDerCorputPolicy(1, measure) @pytest.mark.parametrize("n", [4, 8, 16, 32, 64, 128, 256]) From 0e1008d6efcc91bde366047d6b8518a3ca1c7113 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Wed, 23 Nov 2022 17:01:16 +0100 Subject: [PATCH 14/19] fix type annotation rendering --- src/probnum/quad/solvers/policies/_policy.py | 2 ++ src/probnum/quad/solvers/policies/_random_policy.py | 2 ++ src/probnum/quad/solvers/policies/_van_der_corput_policy.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index 0e0cfbf41..8507929e4 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -1,5 +1,7 @@ """Abstract base class for BQ policies.""" +from __future__ import annotations + import abc from typing import Optional diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index 3884e2c9f..84830ff3e 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -1,5 +1,7 @@ """Random policy for Bayesian Monte Carlo.""" +from __future__ import annotations + from typing import Callable, Optional import numpy as np diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 2010b0a92..755f64952 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -1,5 +1,7 @@ """Van der Corput points for integration on 1D intervals.""" +from __future__ import annotations + from typing import Optional import numpy as np From 328ac52d8a693f66d8bceac5206c854073b5f84c Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Thu, 24 Nov 2022 15:58:02 +0100 Subject: [PATCH 15/19] addressing pr comments --- src/probnum/quad/_bayesquad.py | 4 +-- .../quad/solvers/_bayesian_quadrature.py | 28 +++++++++++-------- src/probnum/quad/solvers/policies/_policy.py | 6 ++++ .../quad/solvers/policies/_random_policy.py | 4 +++ .../policies/_van_der_corput_policy.py | 4 +++ tests/test_quad/test_bayesian_quadrature.py | 26 ++++++++--------- tests/test_quad/test_policy.py | 7 +++++ 7 files changed, 53 insertions(+), 26 deletions(-) diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index df43c72f4..6a9832af4 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -33,7 +33,7 @@ def bayesquad( var_tol: Optional[FloatLike] = None, rel_tol: Optional[FloatLike] = None, batch_size: IntLike = 1, - rng: Optional[np.random.Generator] = np.random.default_rng(), + rng: Optional[np.random.Generator] = None, jitter: FloatLike = 1.0e-8, ) -> Tuple[Normal, BQIterInfo]: r"""Infer the solution of the uni- or multivariate integral @@ -100,7 +100,7 @@ def bayesquad( Number of new observations at each update. Defaults to 1. rng Random number generator. Used by Bayesian Monte Carlo other random sampling - policies. Optional. Default is `np.random.default_rng()`. + policies. jitter Non-negative jitter to numerically stabilise kernel matrix inversion. Defaults to 1e-8. diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index 13247d127..aff7509e7 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Callable, Optional, Tuple, Union, get_args +from typing import Callable, Optional, Tuple import warnings import numpy as np @@ -203,7 +203,7 @@ def bq_iterator( bq_state: BQState, info: Optional[BQIterInfo], fun: Optional[Callable], - rng: np.random.Generator, + rng: Optional[np.random.Generator], ) -> Tuple[Normal, BQState, BQIterInfo]: """Generator that implements the iteration of the BQ method. @@ -269,7 +269,7 @@ def integrate( fun: Optional[Callable], nodes: Optional[np.ndarray], fun_evals: Optional[np.ndarray], - rng: Union[IntLike, np.random.Generator] = np.random.default_rng(), + rng: Optional[np.random.Generator] = None, ) -> Tuple[Normal, BQState, BQIterInfo]: """Integrates the function ``fun``. @@ -290,7 +290,7 @@ def integrate( *shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available from the start. rng - The random number generator used for random methods, or a seed. + The random number generator used for random methods. Returns ------- @@ -302,17 +302,16 @@ def integrate( Raises ------ ValueError - If neither the integrand function (``fun``) nor integrand evaluations - (``fun_evals``) are given. + If neither the integrand function ``fun`` nor integrand evaluations + ``fun_evals`` are given. ValueError - If ``nodes`` are not given and no policy is present. + If neither ``nodes`` nor ``policy`` is given. ValueError If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their shapes do not match. + ValueError + If ``rng`` is not given but ``policy`` requires it. """ - # Get the rng - if isinstance(rng, get_args(IntLike)): - rng = np.random.default_rng(int(rng)) # no policy given: Integrate on fixed dataset. if self.policy is None: @@ -323,13 +322,20 @@ def integrate( # Use fun_evals and disregard fun if both are given if fun is not None and fun_evals is not None: warnings.warn( - "No policy available: 'fun_eval' are used instead of 'fun'." + "No policy available: 'fun_evals' are used instead of 'fun'." ) fun = None # override stopping condition as no policy is given. self.stopping_criterion = ImmediateStop() + # policy given: use policy + elif self.policy.requires_rng and rng is None: + raise ValueError( + f"The policy '{self.policy.__class__.__name__}' requires a random " + f"number generator (rng) to be given." + ) + # Check if integrand function is provided if fun is None and fun_evals is None: raise ValueError( diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index 8507929e4..8fb1a3be7 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -25,6 +25,12 @@ class Policy(abc.ABC): def __init__(self, batch_size: IntLike) -> None: self.batch_size = int(batch_size) + @property + @abc.abstractmethod + def requires_rng(self) -> bool: + """Whether the policy requires a random number generator when called.""" + raise NotImplementedError + @abc.abstractmethod def __call__( self, bq_state: BQState, rng: Optional[np.random.Generator] diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index 84830ff3e..d6f417a85 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -35,6 +35,10 @@ def __init__( super().__init__(batch_size=batch_size) self.sample_func = sample_func + @property + def requires_rng(self) -> bool: + return True + def __call__( self, bq_state: BQState, rng: Optional[np.random.Generator] ) -> np.ndarray: diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 755f64952..c22b93085 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -59,6 +59,10 @@ def __call__( transformed_vdc_seq = vdc_seq * (self.domain_b - self.domain_a) + self.domain_a return transformed_vdc_seq.reshape((self.batch_size, 1)) + @property + def requires_rng(self) -> bool: + return False + @staticmethod def van_der_corput_sequence( n_start: int, n_end: Optional[int] = None diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 58703d230..0f36caf4c 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -96,7 +96,7 @@ def test_bq_from_problem_defaults(bq_no_policy, bq): (1000, 1e-5, 1e-5, LambdaStoppingCriterion), ], ) -def test_bq_from_problem_stopping_condition_assignment(max_evals, var_tol, rel_tol, t): +def test_bq_from_problem_stopping_criterion_assignment(max_evals, var_tol, rel_tol, t): bq = BayesianQuadrature.from_problem( input_dim=2, domain=(0, 1), @@ -108,6 +108,7 @@ def test_bq_from_problem_stopping_condition_assignment(max_evals, var_tol, rel_t def test_integrate_no_policy_wrong_input(bq_no_policy, data): + # The combination of inputs below is important to trigger the correct exception. nodes, fun_evals, fun = data # no nodes provided @@ -119,37 +120,36 @@ def test_integrate_no_policy_wrong_input(bq_no_policy, data): bq_no_policy.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals) -def test_integrate_wrong_input(bq, bq_no_policy, data): +def test_integrate_wrong_input(bq, bq_no_policy, data, rng): + # The combination of inputs below is important to trigger the correct exception. + nodes, fun_evals, fun = data # no integrand provided with pytest.raises(ValueError): - bq.integrate(fun=None, nodes=nodes, fun_evals=None) + bq.integrate(fun=None, nodes=nodes, fun_evals=None, rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=nodes, fun_evals=None) # wrong fun_evals shape with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals[:, None]) + bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals[:, None], rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=nodes, fun_evals=fun_evals[:, None]) # wrong nodes shape with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=nodes[:, None], fun_evals=None) + bq.integrate(fun=fun, nodes=nodes[:, None], fun_evals=fun_evals, rng=rng) with pytest.raises(ValueError): - bq_no_policy.integrate(fun=None, nodes=nodes[:, None], fun_evals=None) + bq_no_policy.integrate(fun=None, nodes=nodes[:, None], fun_evals=fun_evals) # number of points in nodes and fun_evals do not match wrong_nodes = np.vstack([nodes, np.ones([1, nodes.shape[1]])]) with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals) + bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals, rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=wrong_nodes, fun_evals=fun_evals) - -@pytest.mark.parametrize("rng", [np.random.default_rng(42), 42]) -def test_integrate_runs_with_integer_rng(bq, data, rng): - # make sure integrate runs with both a rn generator and a seed. - nodes, fun_evals, fun = data - bq.integrate(fun=fun, nodes=None, fun_evals=None, rng=rng) + # no rng provided but policy requires it + with pytest.raises(ValueError): + bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals, rng=None) diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index 7af3df16a..86402dda4 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -52,6 +52,7 @@ def _get_bq_states(ndim): batch_size=batch_size, sample_func=lambda batch_size, rng: np.ones([batch_size, input_dim]), ) + params["requires_rng"] = True elif policy_name == "VanDerCorputPolicy": # Since VanDerCorputPolicy can only produce univariate nodes, this overrides # input_dim = 1 for all tests. This is a bit cheap, but pytest parametrization @@ -62,6 +63,7 @@ def _get_bq_states(ndim): ) params["bq_state"], params["bq_state_no_data"] = _get_bq_states(1) params["ndim"] = 1 + params["requires_rng"] = False else: raise NotImplementedError @@ -98,6 +100,11 @@ def test_policy_shapes(policy, batch_size, rng): assert policy(bq_state_no_data, rng).shape == (batch_size, ndim) +def test_policy_property_values(policy): + policy, params = policy + assert policy.requires_rng is params["requires_rng"] + + # Tests specific to VanDerCorputPolicy start here From 7f5ba9d56b09a75fe103b5f153059585c34d36b5 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Thu, 24 Nov 2022 16:01:44 +0100 Subject: [PATCH 16/19] small change --- src/probnum/quad/solvers/_bayesian_quadrature.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index aff7509e7..691d36a80 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -329,7 +329,6 @@ def integrate( # override stopping condition as no policy is given. self.stopping_criterion = ImmediateStop() - # policy given: use policy elif self.policy.requires_rng and rng is None: raise ValueError( f"The policy '{self.policy.__class__.__name__}' requires a random " From 956ccb0a27d5b67074b27ff285338ff619e0df84 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Thu, 24 Nov 2022 16:03:52 +0100 Subject: [PATCH 17/19] small change --- .../quad/solvers/policies/_van_der_corput_policy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index c22b93085..c276f5946 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -49,6 +49,10 @@ def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None: self.domain_a = domain_a self.domain_b = domain_b + @property + def requires_rng(self) -> bool: + return False + def __call__( self, bq_state: BQState, rng: Optional[np.random.Generator] ) -> np.ndarray: @@ -59,10 +63,6 @@ def __call__( transformed_vdc_seq = vdc_seq * (self.domain_b - self.domain_a) + self.domain_a return transformed_vdc_seq.reshape((self.batch_size, 1)) - @property - def requires_rng(self) -> bool: - return False - @staticmethod def van_der_corput_sequence( n_start: int, n_end: Optional[int] = None From 67229a26e440861ea347abdf85da7091ed94aac8 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Thu, 24 Nov 2022 16:32:21 +0100 Subject: [PATCH 18/19] fixing tests --- tests/test_quad/test_bayesquad/test_bq.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_quad/test_bayesquad/test_bq.py b/tests/test_quad/test_bayesquad/test_bq.py index 37d08a200..9578c8893 100644 --- a/tests/test_quad/test_bayesquad/test_bq.py +++ b/tests/test_quad/test_bayesquad/test_bq.py @@ -144,7 +144,7 @@ def test_integral_values_sin_lebesgue( @pytest.mark.parametrize("input_dim", [2, 3, 4]) @pytest.mark.parametrize("num_data", [1]) # pylint: disable=invalid-name -def test_integral_values_kernel_translate(kernel, measure, input_dim, x): +def test_integral_values_kernel_translate(kernel, measure, input_dim, x, rng): """Test numerical integration of kernel translates.""" kernel_embedding = KernelEmbedding(kernel, measure) # pylint: disable=cell-var-from-loop @@ -158,6 +158,7 @@ def test_integral_values_kernel_translate(kernel, measure, input_dim, x): var_tol=1e-8, max_evals=1000, batch_size=50, + rng=rng, ) true_integral = kernel_embedding.kernel_mean(np.atleast_2d(translate_point)) np.testing.assert_almost_equal(bq_integral.mean, true_integral, decimal=2) @@ -179,13 +180,13 @@ def test_no_domain_or_measure_raises_error(input_dim): @pytest.mark.parametrize("input_dim", [1]) @pytest.mark.parametrize("measure_name", ["lebesgue"]) -def test_domain_ignored_if_lebesgue(input_dim, measure): +def test_domain_ignored_if_lebesgue(input_dim, measure, rng): domain = (0, 1) fun = lambda x: np.reshape(x, (x.shape[0],)) # standard BQ bq_integral, _ = bayesquad( - fun=fun, input_dim=input_dim, domain=domain, measure=measure + fun=fun, input_dim=input_dim, domain=domain, measure=measure, rng=rng ) assert isinstance(bq_integral, Normal) @@ -199,7 +200,7 @@ def test_domain_ignored_if_lebesgue(input_dim, measure): assert isinstance(bq_integral, Normal) -def test_zero_function_gives_zero_variance_with_mle(): +def test_zero_function_gives_zero_variance_with_mle(rng): """Test that BQ variance is zero for zero function when MLE is used to set the scale parameter.""" input_dim = 1 @@ -209,7 +210,7 @@ def test_zero_function_gives_zero_variance_with_mle(): fun_evals = fun(nodes) bq_integral1, _ = bayesquad( - fun=fun, input_dim=input_dim, domain=domain, scale_estimation="mle" + fun=fun, input_dim=input_dim, domain=domain, scale_estimation="mle", rng=rng ) bq_integral2, _ = bayesquad_from_data( nodes=nodes, fun_evals=fun_evals, domain=domain, scale_estimation="mle" From 7c0839d8874f048fd84931909e102e48d64100d8 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Thu, 24 Nov 2022 16:44:52 +0100 Subject: [PATCH 19/19] fixing doctest --- src/probnum/quad/_bayesquad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index 6a9832af4..6ddf8364e 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -145,9 +145,9 @@ def bayesquad( >>> input_dim = 1 >>> domain = (0, 1) - >>> def f(x): + >>> def fun(x): ... return x.reshape(-1, ) - >>> F, info = bayesquad(fun=f, input_dim=input_dim, domain=domain) + >>> F, info = bayesquad(fun, input_dim, domain=domain, rng=np.random.default_rng(0)) >>> print(F.mean) 0.5 """