Skip to content

Commit 2038d66

Browse files
Custom test quantity support for calibration_ecdf (#528)
* Custom test quantity support for calibration_ecdf * rename variable [no ci] * Consistent defaults for variable_keys/names in calibration_ecdf with test quantiles * Tests for calibration_ecdf with test_quantities * Remove redundant and simplify comments * Fix docstrings and typehints --------- Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
1 parent 13112bc commit 2038d66

File tree

4 files changed

+75
-2
lines changed

4 files changed

+75
-2
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
4+
import keras
45
import matplotlib.pyplot as plt
56

67
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
@@ -13,6 +14,7 @@ def calibration_ecdf(
1314
targets: Mapping[str, np.ndarray] | np.ndarray,
1415
variable_keys: Sequence[str] = None,
1516
variable_names: Sequence[str] = None,
17+
test_quantities: dict[str, Callable] = None,
1618
difference: bool = False,
1719
stacked: bool = False,
1820
rank_type: str | np.ndarray = "fractional",
@@ -78,6 +80,18 @@ def calibration_ecdf(
7880
variable_names : list or None, optional, default: None
7981
The parameter names for nice plot titles.
8082
Inferred if None. Only relevant if `stacked=False`.
83+
test_quantities : dict or None, optional, default: None
84+
A dict that maps plot titles to functions that compute
85+
test quantities based on estimate/target draws.
86+
87+
The dict keys are automatically added to ``variable_keys``
88+
and ``variable_names``.
89+
Test quantity functions are expected to accept a dict of draws with
90+
shape ``(batch_size, ...)`` as the first (typically only)
91+
positional argument and return a NumPy array of shape
92+
``(batch_size,)``.
93+
The functions do not have to deal with an additional
94+
sample dimension, as appropriate reshaping is done internally.
8195
figsize : tuple or None, optional, default: None
8296
The figure size passed to the matplotlib constructor.
8397
Inferred if None.
@@ -120,6 +134,36 @@ def calibration_ecdf(
120134
If an unknown `rank_type` is passed.
121135
"""
122136

137+
# Optionally, compute and prepend test quantities from draws
138+
if test_quantities is not None:
139+
test_quantities_estimates = {}
140+
test_quantities_targets = {}
141+
142+
for key, test_quantity_fn in test_quantities.items():
143+
# Apply test_quantity_fn to ground-truths
144+
tq_targets = test_quantity_fn(data=targets)
145+
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
146+
147+
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
148+
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
149+
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
150+
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
151+
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
152+
153+
# Add custom test quantities to variable keys and names for plotting
154+
# keys and names are set to the test_quantities dict keys
155+
test_quantities_names = list(test_quantities.keys())
156+
157+
if variable_keys is None:
158+
variable_keys = list(estimates.keys())
159+
160+
if isinstance(variable_names, list):
161+
variable_names = test_quantities_names + variable_names
162+
163+
variable_keys = test_quantities_names + variable_keys
164+
estimates = test_quantities_estimates | estimates
165+
targets = test_quantities_targets | targets
166+
123167
plot_data = prepare_plot_data(
124168
estimates=estimates,
125169
targets=targets,

bayesflow/utils/dict_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def dicts_to_arrays(
282282
Ground-truth values corresponding to the estimates. Must match the structure and dimensionality
283283
of `estimates` in terms of first and last axis.
284284
285+
priors : dict[str, ndarray] or ndarray, optional (default = None)
286+
Prior draws. Must match the structure and dimensionality
287+
of `estimates` in terms of first and last axis.
288+
285289
dataset_ids : Sequence of integers indexing the datasets to select (default = None).
286290
By default, use all datasets.
287291

bayesflow/utils/plot_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def prepare_plot_data(
2323
figsize: tuple = None,
2424
stacked: bool = False,
2525
default_name: str = "v",
26-
) -> Mapping[str, Any]:
26+
) -> dict[str, Any]:
2727
"""
2828
Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name
2929
generation, layout configuration, figure initialization, and collapsing of axes.
@@ -56,6 +56,12 @@ def prepare_plot_data(
5656
Whether the plots are stacked horizontally
5757
default_name : str, optional (default = "v")
5858
The default name to use for estimates if None provided
59+
60+
Returns
61+
-------
62+
plot_data : dict[str, Any]
63+
A dictionary containing all preprocessed data and plotting objects required for visualization,
64+
including estimates, targets, variable names, figure, axes, and layout configuration.
5965
"""
6066

6167
plot_data = dicts_to_arrays(

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import bayesflow as bf
2+
import numpy as np
23
import pytest
34

45

@@ -16,6 +17,8 @@ def test_backend():
1617

1718

1819
def test_calibration_ecdf(random_estimates, random_targets, var_names):
20+
print(random_estimates, random_targets, var_names)
21+
1922
# basic functionality: automatic variable names
2023
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
2124
assert len(out.axes) == num_variables(random_estimates)
@@ -46,6 +49,22 @@ def test_calibration_ecdf(random_estimates, random_targets, var_names):
4649
# cannot infer the variable names from an array so default names are used
4750
assert out.axes[1].title._text == "v_1"
4851

52+
# test quantities plots are shown
53+
test_quantities = {
54+
r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
55+
r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
56+
}
57+
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets, test_quantities=test_quantities)
58+
assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
59+
assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
60+
assert out.axes[-1].title._text == r"sigma"
61+
62+
# test plot titles changed to variable_names in case test quantities exist
63+
out = bf.diagnostics.plots.calibration_ecdf(
64+
random_estimates, random_targets, test_quantities=test_quantities, variable_names=var_names
65+
)
66+
assert out.axes[-1].title._text == r"$\sigma$"
67+
4968

5069
def test_calibration_histogram(random_estimates, random_targets):
5170
# basic functionality: automatic variable names

0 commit comments

Comments
 (0)