Skip to content

Commit be1fa10

Browse files
committed
some more minor changes
1 parent 7258d32 commit be1fa10

File tree

2 files changed

+62
-32
lines changed

2 files changed

+62
-32
lines changed

src/probnum/quad/_bayesquad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def multilevel_bayesquad_from_data(
332332
one element is provided, it is inferred that the same nodes ``nodes[0]`` are
333333
used on every level.
334334
fun_diff_evals
335-
Tuple of length :math:`L+1` containing the evaluations of :math:`f_l - f_{l-1}`
335+
Tuple of length :math:`L+1` containing the evaluations of :math:`f_l - f_{l-1}`
336336
for each level at the nodes provided in ``nodes``. Each element must be a
337337
shape=(n_eval,) ``np.ndarray``. The zeroth element contains the evaluations of
338338
:math:`f_0`.
@@ -382,6 +382,7 @@ def multilevel_bayesquad_from_data(
382382
The tuple of kernels provided by the ``kernels`` parameter must contain distinct
383383
kernel instances, i.e., ``kernels[i] is kernel[j]`` must return ``False`` for any
384384
:math:`i\neq j`.
385+
385386
References
386387
----------
387388
.. [1] Li, K., et al., Multilevel Bayesian quadrature, AISTATS, 2023.

tests/test_quad/test_bayesquad/test_bq.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -223,53 +223,84 @@ def test_zero_function_gives_zero_variance_with_mle(rng):
223223

224224

225225
def test_multilevel_bayesquad_from_data_output_types_and_shapes(kernel, measure, rng):
226-
"""Test that inputs to multilevel BQ are handled properly."""
227-
n_level = 3
226+
"""Test correct output for different inputs to multilevel BQ."""
227+
228+
# full set of nodes
228229
ns_1 = (3, 7, 2)
229-
fun_diff_evals_1 = tuple(np.zeros(ns_1[l]) for l in range(n_level))
230-
nodes_full = tuple(measure.sample((ns_1[l]), rng=rng) for l in range(n_level))
230+
n_level_1 = len(ns_1)
231+
fun_diff_evals_1 = tuple(np.zeros(ns_1[l]) for l in range(n_level_1))
232+
nodes_full = tuple(measure.sample((ns_1[l]), rng=rng) for l in range(n_level_1))
231233

234+
# i) default kernel
232235
F, infos = multilevel_bayesquad_from_data(
233236
nodes=nodes_full,
234237
fun_diff_evals=fun_diff_evals_1,
235238
measure=measure,
236239
)
237240
assert isinstance(F, Normal)
238-
assert len(infos) == n_level
239-
# Only one set of nodes
240-
kernels_1 = tuple(copy.deepcopy(kernel) for l in range(n_level))
241-
ns_2 = (7, 7, 7)
242-
fun_diff_evals_2 = n_level * (np.zeros((ns_2[0],)),)
243-
kernels_full = n_level * (kernel,)
241+
assert len(infos) == n_level_1
242+
243+
# ii) full kernel list
244+
kernels_full_1 = tuple(copy.deepcopy(kernel) for _ in range(n_level_1))
245+
F, infos = multilevel_bayesquad_from_data(
246+
nodes=nodes_full,
247+
fun_diff_evals=fun_diff_evals_1,
248+
kernels=kernels_full_1,
249+
measure=measure,
250+
)
251+
assert isinstance(F, Normal)
252+
assert len(infos) == n_level_1
253+
254+
# one set of nodes
255+
n_level_2 = 3
256+
ns_2 = n_level_2 * (7,)
257+
fun_diff_evals_2 = tuple(np.zeros(ns_2[l]) for l in range(n_level_2))
244258
nodes_1 = (measure.sample(n_sample=ns_2[0], rng=rng),)
259+
260+
# i) default kernel
261+
F, infos = multilevel_bayesquad_from_data(
262+
nodes=nodes_1,
263+
fun_diff_evals=fun_diff_evals_2,
264+
measure=measure,
265+
)
266+
assert isinstance(F, Normal)
267+
assert len(infos) == n_level_2
268+
269+
# ii) full kernel list
270+
kernels_full_2 = tuple(copy.deepcopy(kernel) for _ in range(n_level_2))
245271
F, infos = multilevel_bayesquad_from_data(
246272
nodes=nodes_1,
247273
fun_diff_evals=fun_diff_evals_2,
248-
kernels=kernels_full,
274+
kernels=kernels_full_2,
249275
measure=measure,
250276
)
251277
assert isinstance(F, Normal)
252-
assert len(infos) == n_level
278+
assert len(infos) == n_level_2
253279

254280

255281
def test_multilevel_bayesquad_from_data_wrong_inputs(kernel, measure, rng):
256282
"""Tests that wrong number inputs to multilevel BQ throw errors."""
257-
n_level = 5
258283
ns = (3, 7, 11)
259-
nodes_1 = (measure.sample(n_sample=ns[0], rng=rng),)
260-
fun_diff_evals = n_level * (np.zeros((ns[0],)),)
261-
kernels = (kernel, kernel)
284+
n_level = len(ns)
285+
fun_diff_evals = tuple(np.zeros(ns[l]) for l in range(n_level))
286+
287+
# number of nodes does not match the number of fun evals
288+
wrong_n_nodes = 2
289+
nodes_2 = tuple(measure.sample((ns[l]), rng=rng) for l in range(wrong_n_nodes))
262290
with pytest.raises(ValueError):
263-
_, _ = multilevel_bayesquad_from_data(
264-
nodes=nodes_1,
291+
multilevel_bayesquad_from_data(
292+
nodes=nodes_2,
265293
fun_diff_evals=fun_diff_evals,
266-
kernels=kernels,
267294
measure=measure,
268295
)
269-
nodes_2 = tuple(measure.sample((ns[l]), rng=rng) for l in range(2))
296+
297+
# number of kernels does not match number of fun evals
298+
wrong_n_kernels = 2
299+
kernels = tuple(copy.deepcopy(kernel) for _ in range(wrong_n_kernels))
300+
nodes_1 = (measure.sample(n_sample=ns[0], rng=rng),)
270301
with pytest.raises(ValueError):
271-
_, _ = multilevel_bayesquad_from_data(
272-
nodes=nodes_2,
302+
multilevel_bayesquad_from_data(
303+
nodes=nodes_1,
273304
fun_diff_evals=fun_diff_evals,
274305
kernels=kernels,
275306
measure=measure,
@@ -279,18 +310,16 @@ def test_multilevel_bayesquad_from_data_wrong_inputs(kernel, measure, rng):
279310
def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_1d():
280311
"""Test that multilevel BQ equals BQ when all but one level are given non-zero
281312
function evaluations for 1D data."""
282-
input_dim = 1
283313
n_level = 5
284314
domain = (0, 3.3)
285-
nodes = ()
286-
nodes = [np.linspace(0, 1, 2 * l + 1)[:, None] for l in range(n_level)]
315+
nodes = tuple(np.linspace(0, 1, 2 * l + 1)[:, None] for l in range(n_level))
287316
for i in range(n_level):
288317
jitter = 1e-5 * (i + 1.0)
289-
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
290318
fun_evals = nodes[i][:, 0] ** (2 + 0.3 * i) + 1.2
319+
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
291320
fun_diff_evals[i] = fun_evals
292321
mlbq_integral, _ = multilevel_bayesquad_from_data(
293-
nodes=tuple(nodes),
322+
nodes=nodes,
294323
fun_diff_evals=tuple(fun_diff_evals),
295324
domain=domain,
296325
options=dict(jitter=jitter),
@@ -312,16 +341,16 @@ def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_2d():
312341
n_level = 5
313342
measure = GaussianMeasure(np.full((input_dim,), 0.2), cov=0.6 * np.eye(input_dim))
314343
_gh = gauss_hermite_tensor
315-
nodes = [
344+
nodes = tuple(
316345
_gh(l + 1, input_dim, measure.mean, measure.cov)[0] for l in range(n_level)
317-
]
346+
)
318347
for i in range(n_level):
319348
jitter = 1e-5 * (i + 1.0)
320-
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
321349
fun_evals = np.sin(nodes[i][:, 0] * i) + (i + 1.0) * np.cos(nodes[i][:, 1])
350+
fun_diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
322351
fun_diff_evals[i] = fun_evals
323352
mlbq_integral, _ = multilevel_bayesquad_from_data(
324-
nodes=tuple(nodes),
353+
nodes=nodes,
325354
fun_diff_evals=tuple(fun_diff_evals),
326355
measure=measure,
327356
options=dict(jitter=jitter),

0 commit comments

Comments
 (0)