@@ -223,53 +223,84 @@ def test_zero_function_gives_zero_variance_with_mle(rng):
223223
224224
225225def test_multilevel_bayesquad_from_data_output_types_and_shapes (kernel , measure , rng ):
226- """Test that inputs to multilevel BQ are handled properly."""
227- n_level = 3
226+ """Test correct output for different inputs to multilevel BQ."""
227+
228+ # full set of nodes
228229 ns_1 = (3 , 7 , 2 )
229- fun_diff_evals_1 = tuple (np .zeros (ns_1 [l ]) for l in range (n_level ))
230- nodes_full = tuple (measure .sample ((ns_1 [l ]), rng = rng ) for l in range (n_level ))
230+ n_level_1 = len (ns_1 )
231+ fun_diff_evals_1 = tuple (np .zeros (ns_1 [l ]) for l in range (n_level_1 ))
232+ nodes_full = tuple (measure .sample ((ns_1 [l ]), rng = rng ) for l in range (n_level_1 ))
231233
234+ # i) default kernel
232235 F , infos = multilevel_bayesquad_from_data (
233236 nodes = nodes_full ,
234237 fun_diff_evals = fun_diff_evals_1 ,
235238 measure = measure ,
236239 )
237240 assert isinstance (F , Normal )
238- assert len (infos ) == n_level
239- # Only one set of nodes
240- kernels_1 = tuple (copy .deepcopy (kernel ) for l in range (n_level ))
241- ns_2 = (7 , 7 , 7 )
242- fun_diff_evals_2 = n_level * (np .zeros ((ns_2 [0 ],)),)
243- kernels_full = n_level * (kernel ,)
241+ assert len (infos ) == n_level_1
242+
243+ # ii) full kernel list
244+ kernels_full_1 = tuple (copy .deepcopy (kernel ) for _ in range (n_level_1 ))
245+ F , infos = multilevel_bayesquad_from_data (
246+ nodes = nodes_full ,
247+ fun_diff_evals = fun_diff_evals_1 ,
248+ kernels = kernels_full_1 ,
249+ measure = measure ,
250+ )
251+ assert isinstance (F , Normal )
252+ assert len (infos ) == n_level_1
253+
254+ # one set of nodes
255+ n_level_2 = 3
256+ ns_2 = n_level_2 * (7 ,)
257+ fun_diff_evals_2 = tuple (np .zeros (ns_2 [l ]) for l in range (n_level_2 ))
244258 nodes_1 = (measure .sample (n_sample = ns_2 [0 ], rng = rng ),)
259+
260+ # i) default kernel
261+ F , infos = multilevel_bayesquad_from_data (
262+ nodes = nodes_1 ,
263+ fun_diff_evals = fun_diff_evals_2 ,
264+ measure = measure ,
265+ )
266+ assert isinstance (F , Normal )
267+ assert len (infos ) == n_level_2
268+
269+ # ii) full kernel list
270+ kernels_full_2 = tuple (copy .deepcopy (kernel ) for _ in range (n_level_2 ))
245271 F , infos = multilevel_bayesquad_from_data (
246272 nodes = nodes_1 ,
247273 fun_diff_evals = fun_diff_evals_2 ,
248- kernels = kernels_full ,
274+ kernels = kernels_full_2 ,
249275 measure = measure ,
250276 )
251277 assert isinstance (F , Normal )
252- assert len (infos ) == n_level
278+ assert len (infos ) == n_level_2
253279
254280
255281def test_multilevel_bayesquad_from_data_wrong_inputs (kernel , measure , rng ):
256282 """Tests that wrong number inputs to multilevel BQ throw errors."""
257- n_level = 5
258283 ns = (3 , 7 , 11 )
259- nodes_1 = (measure .sample (n_sample = ns [0 ], rng = rng ),)
260- fun_diff_evals = n_level * (np .zeros ((ns [0 ],)),)
261- kernels = (kernel , kernel )
284+ n_level = len (ns )
285+ fun_diff_evals = tuple (np .zeros (ns [l ]) for l in range (n_level ))
286+
287+ # number of nodes does not match the number of fun evals
288+ wrong_n_nodes = 2
289+ nodes_2 = tuple (measure .sample ((ns [l ]), rng = rng ) for l in range (wrong_n_nodes ))
262290 with pytest .raises (ValueError ):
263- _ , _ = multilevel_bayesquad_from_data (
264- nodes = nodes_1 ,
291+ multilevel_bayesquad_from_data (
292+ nodes = nodes_2 ,
265293 fun_diff_evals = fun_diff_evals ,
266- kernels = kernels ,
267294 measure = measure ,
268295 )
269- nodes_2 = tuple (measure .sample ((ns [l ]), rng = rng ) for l in range (2 ))
296+
297+ # number of kernels does not match number of fun evals
298+ wrong_n_kernels = 2
299+ kernels = tuple (copy .deepcopy (kernel ) for _ in range (wrong_n_kernels ))
300+ nodes_1 = (measure .sample (n_sample = ns [0 ], rng = rng ),)
270301 with pytest .raises (ValueError ):
271- _ , _ = multilevel_bayesquad_from_data (
272- nodes = nodes_2 ,
302+ multilevel_bayesquad_from_data (
303+ nodes = nodes_1 ,
273304 fun_diff_evals = fun_diff_evals ,
274305 kernels = kernels ,
275306 measure = measure ,
@@ -279,18 +310,16 @@ def test_multilevel_bayesquad_from_data_wrong_inputs(kernel, measure, rng):
279310def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_1d ():
280311 """Test that multilevel BQ equals BQ when all but one level are given non-zero
281312 function evaluations for 1D data."""
282- input_dim = 1
283313 n_level = 5
284314 domain = (0 , 3.3 )
285- nodes = ()
286- nodes = [np .linspace (0 , 1 , 2 * l + 1 )[:, None ] for l in range (n_level )]
315+ nodes = tuple (np .linspace (0 , 1 , 2 * l + 1 )[:, None ] for l in range (n_level ))
287316 for i in range (n_level ):
288317 jitter = 1e-5 * (i + 1.0 )
289- fun_diff_evals = [np .zeros (shape = (len (xs ),)) for xs in nodes ]
290318 fun_evals = nodes [i ][:, 0 ] ** (2 + 0.3 * i ) + 1.2
319+ fun_diff_evals = [np .zeros (shape = (len (xs ),)) for xs in nodes ]
291320 fun_diff_evals [i ] = fun_evals
292321 mlbq_integral , _ = multilevel_bayesquad_from_data (
293- nodes = tuple ( nodes ) ,
322+ nodes = nodes ,
294323 fun_diff_evals = tuple (fun_diff_evals ),
295324 domain = domain ,
296325 options = dict (jitter = jitter ),
@@ -312,16 +341,16 @@ def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_2d():
312341 n_level = 5
313342 measure = GaussianMeasure (np .full ((input_dim ,), 0.2 ), cov = 0.6 * np .eye (input_dim ))
314343 _gh = gauss_hermite_tensor
315- nodes = [
344+ nodes = tuple (
316345 _gh (l + 1 , input_dim , measure .mean , measure .cov )[0 ] for l in range (n_level )
317- ]
346+ )
318347 for i in range (n_level ):
319348 jitter = 1e-5 * (i + 1.0 )
320- fun_diff_evals = [np .zeros (shape = (len (xs ),)) for xs in nodes ]
321349 fun_evals = np .sin (nodes [i ][:, 0 ] * i ) + (i + 1.0 ) * np .cos (nodes [i ][:, 1 ])
350+ fun_diff_evals = [np .zeros (shape = (len (xs ),)) for xs in nodes ]
322351 fun_diff_evals [i ] = fun_evals
323352 mlbq_integral , _ = multilevel_bayesquad_from_data (
324- nodes = tuple ( nodes ) ,
353+ nodes = nodes ,
325354 fun_diff_evals = tuple (fun_diff_evals ),
326355 measure = measure ,
327356 options = dict (jitter = jitter ),
0 commit comments