
Commit 7b27f14

Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into dev
2 parents: e4e6da4 + 55d51df

8 files changed: +267 additions, −14 deletions

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@
 from .expected_calibration_error import expected_calibration_error
 from .classifier_two_sample_test import classifier_two_sample_test
 from .model_misspecification import bootstrap_comparison, summary_space_comparison
+from .sbc import log_gamma

bayesflow/diagnostics/metrics/sbc.py (new file)

Lines changed: 162 additions & 0 deletions

from collections.abc import Mapping, Sequence

import numpy as np
from scipy.stats import binom

from ...utils.dict_utils import dicts_to_arrays


def log_gamma(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    num_null_draws: int = 1000,
    quantile: float = 0.05,
):
    """
    Compute the log gamma discrepancy statistic, see [1] for additional information.
    Log gamma is log(gamma / gamma_null), where gamma_null is the 5th percentile of the
    null distribution under uniformity of ranks.
    That is, if adopting a hypothesis testing framework, then log_gamma < 0 implies
    a rejection of the hypothesis of uniform ranks at the 5% level.
    This diagnostic is typically more sensitive than the Kolmogorov-Smirnov test or
    the chi-squared test.

    [1] Martin Modrák, Angie H. Moon, Shinyoung Kim, Paul Bürkner, Niko Huurre,
    Kateřina Faltejsková, Andrew Gelman, Aki Vehtari.
    "Simulation-Based Calibration Checking for Bayesian Computation:
    The Choice of Test Quantities Shapes Sensitivity."
    Bayesian Analysis 20(2), 461-488, June 2025. https://doi.org/10.1214/23-BA1404

    Parameters
    ----------
    estimates : np.ndarray of shape (num_datasets, num_draws, num_variables)
        The random draws from the approximate posteriors over ``num_datasets``
    targets : np.ndarray of shape (num_datasets, num_variables)
        The corresponding ground-truth values sampled from the prior
    variable_keys : Sequence[str], optional (default = None)
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to show in the output.
    num_null_draws : int, optional, default 1000
        Number of draws from the null distribution used to estimate the threshold quantile.
    quantile : float in (0, 1), optional, default 0.05
        The quantile from the null distribution to be used as a threshold.
        A lower quantile increases sensitivity to deviations from uniformity.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The log gamma values per variable
        - "metric_name" : str
            The name of the metric ("Log Gamma").
        - "variable_names" : str
            The (inferred) variable names.
    """
    samples = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    num_ranks = samples["estimates"].shape[0]
    num_post_draws = samples["estimates"].shape[1]

    # rank statistics
    ranks = np.sum(samples["estimates"] < samples["targets"][:, None], axis=1)

    # null distribution and threshold
    null_distribution = gamma_null_distribution(num_ranks, num_post_draws, num_null_draws)
    null_quantile = np.quantile(null_distribution, quantile)

    # compute log gamma for each parameter
    log_gammas = np.empty(ranks.shape[-1])

    for i in range(ranks.shape[-1]):
        gamma = gamma_discrepancy(ranks[:, i], num_post_draws=num_post_draws)
        log_gammas[i] = np.log(gamma / null_quantile)

    output = {
        "values": log_gammas,
        "metric_name": "Log Gamma",
        "variable_names": samples["estimates"].variable_names,
    }

    return output


def gamma_null_distribution(num_ranks: int, num_post_draws: int = 1000, num_null_draws: int = 1000) -> np.ndarray:
    """
    Computes the distribution of expected gamma values under uniformity of ranks.

    Parameters
    ----------
    num_ranks : int
        Number of ranks to use for each gamma.
    num_post_draws : int, optional, default 1000
        Number of posterior draws that were used to calculate the rank distribution.
    num_null_draws : int, optional, default 1000
        Number of returned gamma values under uniformity of ranks.

    Returns
    -------
    result : np.ndarray
        Array of shape (num_null_draws,) containing gamma values under uniformity of ranks.
    """
    z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)
    gamma = np.empty(num_null_draws)

    # loop non-vectorized to reduce memory footprint
    for i in range(num_null_draws):
        u = np.random.uniform(size=num_ranks)
        F_z = np.mean(u[:, None] < z_i, axis=0)
        bin_1 = binom.cdf(num_ranks * F_z, num_ranks, z_i)
        bin_2 = 1 - binom.cdf(num_ranks * F_z - 1, num_ranks, z_i)

        gamma[i] = 2 * np.min(np.minimum(bin_1, bin_2))

    return gamma


def gamma_discrepancy(ranks: np.ndarray, num_post_draws: int = 100) -> float:
    """
    Quantifies deviation from uniformity by the likelihood of observing the
    most extreme point on the empirical CDF of the given rank distribution
    according to [1] (equation 7).

    [1] Martin Modrák, Angie H. Moon, Shinyoung Kim, Paul Bürkner, Niko Huurre,
    Kateřina Faltejsková, Andrew Gelman, Aki Vehtari.
    "Simulation-Based Calibration Checking for Bayesian Computation:
    The Choice of Test Quantities Shapes Sensitivity."
    Bayesian Analysis 20(2), 461-488, June 2025. https://doi.org/10.1214/23-BA1404

    Parameters
    ----------
    ranks : array of shape (num_ranks,)
        Empirical rank distribution
    num_post_draws : int, optional, default 100
        Number of posterior draws used to generate ranks.

    Returns
    -------
    result : float
        Gamma discrepancy value for the given rank distribution.
    """
    num_ranks = len(ranks)

    # observed count of ranks smaller than i
    R_i = np.array([sum(ranks < i) for i in range(1, num_post_draws + 2)])

    # expected proportion of ranks smaller than i
    z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)

    bin_1 = binom.cdf(R_i, num_ranks, z_i)
    bin_2 = 1 - binom.cdf(R_i - 1, num_ranks, z_i)

    # likelihood of obtaining the most extreme point on the empirical CDF
    # if the rank distribution was indeed uniform
    return float(2 * np.min(np.minimum(bin_1, bin_2)))
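
(Aside, not part of the commit.) A minimal usage sketch of the new metric, assuming the dict-based calling convention suggested by the type hints and by the tests further below; the toy setup, the variable name "theta", and the seed are illustrative choices, not taken from the diff. Because the "posterior" draws come from the same distribution as the targets and ignore any data, the ranks are uniform by construction, so log_gamma should typically come out positive (no rejection of uniformity).

import numpy as np
import bayesflow as bf

rng = np.random.default_rng(1)
num_datasets, num_draws = 200, 500

# ground-truth parameters and "posterior" draws from the same distribution -> uniform ranks
targets = {"theta": rng.normal(size=(num_datasets, 2))}
estimates = {"theta": rng.normal(size=(num_datasets, num_draws, 2))}

out = bf.diagnostics.metrics.log_gamma(estimates, targets)
print(out["metric_name"], out["variable_names"], out["values"])
# values below 0 would indicate a rejection of uniform ranks at the chosen quantile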

bayesflow/networks/standardization/standardization.py

Lines changed: 5 additions & 5 deletions

@@ -40,7 +40,7 @@ def moving_std(self, index: int) -> Tensor:
         """
         return keras.ops.where(
             self.moving_m2[index] > 0,
-            keras.ops.sqrt(self.moving_m2[index] / self.count),
+            keras.ops.sqrt(self.moving_m2[index] / self.count[index]),
             1.0,
         )

@@ -53,7 +53,7 @@ def build(self, input_shape: Shape):
         self.moving_m2 = [
             self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-        self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
+        self.count = [self.add_weight(shape=(), initializer="zeros", trainable=False) for _ in flattened_shapes]

     def call(
         self,

@@ -150,7 +150,7 @@ def _update_moments(self, x: Tensor, index: int):
         """

         reduce_axes = tuple(range(x.ndim - 1))
-        batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
+        batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count[index].dtype)

         # Compute batch mean and M2 per feature
         batch_mean = keras.ops.mean(x, axis=reduce_axes)

@@ -159,7 +159,7 @@ def _update_moments(self, x: Tensor, index: int):
         # Read current totals
         mean = self.moving_mean[index]
         m2 = self.moving_m2[index]
-        count = self.count
+        count = self.count[index]

         total_count = count + batch_count
         delta = batch_mean - mean

@@ -169,4 +169,4 @@ def _update_moments(self, x: Tensor, index: int):

         self.moving_mean[index].assign(new_mean)
         self.moving_m2[index].assign(new_m2)
-        self.count.assign(total_count)
+        self.count[index].assign(total_count)
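
(Aside, not part of the diff.) The `_update_moments` hunks above follow a Chan et al. style pooled update of the running mean and M2; the change matters because each input tensor can now contribute a different number of samples, so the running count has to be stored per index rather than shared. A small NumPy sketch of that update, with shapes and names chosen only for illustration:

import numpy as np

def update(mean, m2, count, x):
    """Merge a new batch x of shape (..., features) into running (mean, m2, count)."""
    x2d = x.reshape(-1, x.shape[-1])   # flatten all but the feature axis
    batch_count = x2d.shape[0]         # matches prod(shape[:-1]) in the diff
    batch_mean = x2d.mean(axis=0)
    batch_m2 = ((x2d - batch_mean) ** 2).sum(axis=0)

    total = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * batch_count / total
    new_m2 = m2 + batch_m2 + delta**2 * count * batch_count / total
    return new_mean, new_m2, total

rng = np.random.default_rng(0)
a1, a2 = rng.normal(size=(2, 3, 5)), rng.normal(size=(3, 3, 5))  # different leading shapes

mean, m2, count = np.zeros(5), np.zeros(5), 0
for x in (a1, a2):
    mean, m2, count = update(mean, m2, count, x)

full = np.concatenate([a1, a2]).reshape(-1, 5)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(np.sqrt(m2 / count), full.std(axis=0))  # population std, as in moving_std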

bayesflow/scores/multivariate_normal_score.py

Lines changed: 8 additions & 6 deletions

@@ -82,13 +82,15 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         diff = x - mean

-        # Calculate covariance from Cholesky factors
-        covariance = keras.ops.matmul(
-            cov_chol,
-            keras.ops.swapaxes(cov_chol, -2, -1),
+        # Calculate precision from Cholesky factors of covariance matrix
+        cov_chol_inv = keras.ops.inv(cov_chol)
+        precision = keras.ops.matmul(
+            keras.ops.swapaxes(cov_chol_inv, -2, -1),
+            cov_chol_inv,
         )
-        precision = keras.ops.inv(covariance)
-        log_det_covariance = keras.ops.slogdet(covariance)[1]  # Only take the log of the determinant part
+
+        # Compute log determinant, exploiting Cholesky factors
+        log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2

         # Compute the quadratic term in the exponential of the multivariate Gaussian
         quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
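
(Aside, not part of the diff.) A standalone NumPy check of the linear algebra behind this hunk: for a covariance Sigma = L L^T with Cholesky factor L, the precision equals (L^-1)^T L^-1 and log|Sigma| = 2 * log(prod(diag(L))), so inverting only the triangular factor and reading the determinant off its diagonal reproduces the previous inv/slogdet path. The 4x4 example matrix is arbitrary.

import numpy as np

rng = np.random.default_rng(0)

# random SPD covariance and its lower-triangular Cholesky factor
a = rng.normal(size=(4, 4))
cov = a @ a.T + 4 * np.eye(4)
chol = np.linalg.cholesky(cov)

# precision via the Cholesky factor, as in the new code: (L^-1)^T @ L^-1
chol_inv = np.linalg.inv(chol)
precision = chol_inv.T @ chol_inv
assert np.allclose(precision, np.linalg.inv(cov))

# log-determinant via the Cholesky diagonal: log|Sigma| = 2 * log(prod(diag(L)))
log_det = 2 * np.log(np.prod(np.diag(chol)))
assert np.allclose(log_det, np.linalg.slogdet(cov)[1])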

tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py

Lines changed: 2 additions & 1 deletion

@@ -8,7 +8,8 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
     approximator.build(data_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0
     approximator.compute_metrics(**train_dataset[0])

     keras.saving.save_model(approximator, tmp_path / "model.keras")

tests/test_approximators/test_build.py

Lines changed: 2 additions & 1 deletion

@@ -14,4 +14,5 @@ def test_build(approximator, simulator, batch_size, adapter):
     approximator.build(batch_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 54 additions & 1 deletion

@@ -1,6 +1,7 @@
-import numpy as np
 import keras
+import numpy as np
 import pytest
+from scipy.stats import binom

 import bayesflow as bf

@@ -84,6 +85,58 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
     out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)


+def test_log_gamma(random_estimates, random_targets):
+    out = bf.diagnostics.metrics.log_gamma(random_estimates, random_targets)
+    assert list(out.keys()) == ["values", "metric_name", "variable_names"]
+    assert out["values"].shape == (num_variables(random_estimates),)
+    assert out["metric_name"] == "Log Gamma"
+    assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+
+
+def test_log_gamma_end_to_end():
+    # This is a functional test for simulation-based calibration.
+    # First, we sample from a known generative process and then run SBC.
+    # If the log gamma statistic is correctly implemented, a 95% interval should exclude
+    # the true value 5% of the time.
+
+    N = 30  # number of samples
+    S = 1000  # number of posterior draws
+    D = 1000  # number of datasets
+
+    def run_sbc(N=N, S=S, D=D, bias=0):
+        rng = np.random.default_rng()
+        prior_draws = rng.beta(2, 2, size=D)
+        successes = rng.binomial(N, prior_draws)
+
+        # Analytical posterior:
+        # if theta ~ Beta(2, 2), then p(theta | successes) is Beta(2 + successes, 2 + N - successes).
+        posterior_draws = rng.beta(2 + successes + bias, 2 + N - successes + bias, size=(S, D))
+
+        # these ranks are uniform if bias=0
+        ranks = np.sum(posterior_draws < prior_draws, axis=0)
+
+        # this is the distribution of gamma under uniform ranks
+        gamma_null = bf.diagnostics.metrics.sbc.gamma_null_distribution(D, S, num_null_draws=100)
+        lower, upper = np.quantile(gamma_null, (0.05, 0.995))
+
+        # this is the empirical gamma
+        observed_gamma = bf.diagnostics.metrics.sbc.gamma_discrepancy(ranks, num_post_draws=S)
+
+        in_interval = lower <= observed_gamma < upper
+
+        return in_interval
+
+    sbc_calibration = [run_sbc(N=N, S=S, D=D) for _ in range(100)]
+    lower_expected, upper_expected = binom.ppf((0.0005, 0.9995), 100, 0.95)
+
+    # this test should fail with a probability of 0.1%
+    assert lower_expected <= np.sum(sbc_calibration) <= upper_expected
+
+    # SBC should almost always fail for slightly biased posterior draws
+    sbc_calibration = [run_sbc(N=N, S=S, D=D, bias=1) for _ in range(100)]
+    assert not lower_expected <= np.sum(sbc_calibration) <= upper_expected
+
+
 def test_bootstrap_comparison_shapes():
     """Test the bootstrap_comparison output shapes."""
     observed_samples = np.random.rand(10, 5)

tests/test_networks/test_standardization.py

Lines changed: 33 additions & 0 deletions

@@ -91,6 +91,39 @@ def test_nested_consistency_forward_inverse():
     np.testing.assert_allclose(random_input["b"], recovered["b"], atol=1e-4)


+def test_nested_accuracy_forward():
+    from bayesflow.utils import tree_concatenate
+
+    # create inputs for two training passes
+    random_input_a_1 = keras.random.normal((2, 3, 5))
+    random_input_b_1 = keras.random.normal((4, 3))
+    random_input_1 = {"a": random_input_a_1, "b": random_input_b_1}
+
+    random_input_a_2 = keras.random.normal((3, 3, 5))
+    random_input_b_2 = keras.random.normal((3, 3))
+    random_input_2 = {"a": random_input_a_2, "b": random_input_b_2}
+
+    # complete data for testing that mean and std are 0 and 1
+    random_input = tree_concatenate([random_input_1, random_input_2], axis=0)
+
+    layer = Standardization()
+
+    _ = layer(random_input_1, stage="training", forward=True)
+    _ = layer(random_input_2, stage="training", forward=True)
+
+    standardized = layer(random_input, stage="inference", forward=True)
+    standardized = keras.tree.map_structure(keras.ops.convert_to_numpy, standardized)
+
+    np.testing.assert_allclose(
+        np.mean(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(
+        np.mean(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(np.std(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 1.0, atol=1e-4)
+    np.testing.assert_allclose(np.std(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 1.0, atol=1e-4)
+
+
 def test_transformation_type_both_sides_scale():
     # Fix a known covariance and mean in original (not standardized space)
     covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
