Skip to content

Commit 0ea79d7

Browse files
committed
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into dev
2 parents c9feff2 + 2038d66 commit 0ea79d7

File tree

12 files changed

+234
-42
lines changed

12 files changed

+234
-42
lines changed

bayesflow/adapters/transforms/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray
117117
data[k] = np.expand_dims(data[k], axis=tuple(np.arange(0, len_diff)))
118118
elif self.expand == "right":
119119
data[k] = np.expand_dims(data[k], axis=tuple(-np.arange(1, len_diff + 1)))
120-
elif isinstance(self.expand, tuple):
120+
elif isinstance(self.expand, Sequence):
121121
if len(self.expand) is not len_diff:
122122
raise ValueError("Length of `expand` must match the length difference of the involed arrays.")
123123
data[k] = np.expand_dims(data[k], axis=self.expand)

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
4+
import keras
45
import matplotlib.pyplot as plt
56

67
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
@@ -13,6 +14,7 @@ def calibration_ecdf(
1314
targets: Mapping[str, np.ndarray] | np.ndarray,
1415
variable_keys: Sequence[str] = None,
1516
variable_names: Sequence[str] = None,
17+
test_quantities: dict[str, Callable] = None,
1618
difference: bool = False,
1719
stacked: bool = False,
1820
rank_type: str | np.ndarray = "fractional",
@@ -78,6 +80,18 @@ def calibration_ecdf(
7880
variable_names : list or None, optional, default: None
7981
The parameter names for nice plot titles.
8082
Inferred if None. Only relevant if `stacked=False`.
83+
test_quantities : dict or None, optional, default: None
84+
A dict that maps plot titles to functions that compute
85+
test quantities based on estimate/target draws.
86+
87+
The dict keys are automatically added to ``variable_keys``
88+
and ``variable_names``.
89+
Test quantity functions are expected to accept a dict of draws with
90+
shape ``(batch_size, ...)`` as the first (typically only)
91+
positional argument and return a NumPy array of shape
92+
``(batch_size,)``.
93+
The functions do not have to deal with an additional
94+
sample dimension, as appropriate reshaping is done internally.
8195
figsize : tuple or None, optional, default: None
8296
The figure size passed to the matplotlib constructor.
8397
Inferred if None.
@@ -120,6 +134,36 @@ def calibration_ecdf(
120134
If an unknown `rank_type` is passed.
121135
"""
122136

137+
# Optionally, compute and prepend test quantities from draws
138+
if test_quantities is not None:
139+
test_quantities_estimates = {}
140+
test_quantities_targets = {}
141+
142+
for key, test_quantity_fn in test_quantities.items():
143+
# Apply test_quantity_fn to ground-truths
144+
tq_targets = test_quantity_fn(data=targets)
145+
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
146+
147+
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
148+
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
149+
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
150+
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
151+
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
152+
153+
# Add custom test quantities to variable keys and names for plotting
154+
# keys and names are set to the test_quantities dict keys
155+
test_quantities_names = list(test_quantities.keys())
156+
157+
if variable_keys is None:
158+
variable_keys = list(estimates.keys())
159+
160+
if isinstance(variable_names, list):
161+
variable_names = test_quantities_names + variable_names
162+
163+
variable_keys = test_quantities_names + variable_keys
164+
estimates = test_quantities_estimates | estimates
165+
targets = test_quantities_targets | targets
166+
123167
plot_data = prepare_plot_data(
124168
estimates=estimates,
125169
targets=targets,

bayesflow/experimental/cif/cif.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _inverse(
9999
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
100100
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
101101

102-
elbo = self.log_prob(x, conditions=conditions)
102+
elbo = self.log_prob(x, conditions=conditions, training=stage == "training")
103103

104104
loss = -keras.ops.mean(elbo)
105105

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def compute_metrics(
183183
) -> dict[str, Tensor]:
184184
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
185185

186-
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
186+
z, log_density = self(x, conditions=conditions, inverse=False, density=True, training=stage == "training")
187187
loss = weighted_mean(-log_density, sample_weight)
188188

189189
return base_metrics | {"loss": loss}

bayesflow/networks/point_inference_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def call(
145145
def compute_metrics(
146146
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
147147
) -> dict[str, Tensor]:
148-
output = self(x, conditions)
148+
output = self(x, conditions, training=stage == "training")
149149

150150
metrics = {}
151151
# calculate negative score as mean over all scores

bayesflow/utils/dict_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def dicts_to_arrays(
282282
Ground-truth values corresponding to the estimates. Must match the structure and dimensionality
283283
of `estimates` in terms of first and last axis.
284284
285+
priors : dict[str, ndarray] or ndarray, optional (default = None)
286+
Prior draws. Must match the structure and dimensionality
287+
of `estimates` in terms of first and last axis.
288+
285289
dataset_ids : Sequence of integers indexing the datasets to select (default = None).
286290
By default, use all datasets.
287291

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
99
"""
1010
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
11-
Significantly slower than the unstabilized version, so use only when you need numerical stability.
11+
About 50% slower than the unstabilized version, so use only when you need numerical stability.
1212
"""
1313
log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
14-
assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
14+
assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
1515
assignments = keras.ops.squeeze(assignments, axis=1)
1616

1717
return assignments
@@ -20,19 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
2020
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
2121
"""
2222
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
23-
Significantly slower than the unstabilized version, so use only when you need numerical stability.
23+
About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
2424
"""
2525
cost = euclidean(x1, x2)
26+
cost_scaled = -cost / regularization
2627

27-
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
28+
# initialize transport plan from a gaussian kernel
29+
log_plan = cost_scaled - keras.ops.max(cost_scaled)
30+
n, m = keras.ops.shape(log_plan)
31+
32+
log_a = -keras.ops.log(n)
33+
log_b = -keras.ops.log(m)
2834

2935
def contains_nans(plan):
3036
return keras.ops.any(keras.ops.isnan(plan))
3137

3238
def is_converged(plan):
33-
# for convergence, the plan should be doubly stochastic
34-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
35-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
39+
# for convergence, the target marginals must match
40+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
41+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
3642
return conv0 & conv1
3743

3844
def cond(_, plan):
@@ -41,8 +47,8 @@ def cond(_, plan):
4147

4248
def body(steps, plan):
4349
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
44-
plan = keras.ops.log_softmax(plan, axis=0)
45-
plan = keras.ops.log_softmax(plan, axis=1)
50+
plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
51+
plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a
4652

4753
return steps + 1, plan
4854

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
1111
"""
1212
Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
1313
14-
Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic
14+
Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
1515
transport plan, containing assignment probabilities.
1616
The permutation is then sampled randomly according to the transport plan.
1717
@@ -27,12 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
2727
:param seed: Random seed to use for sampling indices.
2828
Default: None, which means the seed will be auto-determined for non-compiled contexts.
2929
30-
:return: Tensor of shape (m,)
30+
:return: Tensor of shape (n,)
3131
Assignment indices for x2.
3232
3333
"""
3434
plan = sinkhorn_plan(x1, x2, **kwargs)
35-
assignments = keras.random.categorical(plan, num_samples=1, seed=seed)
35+
36+
# we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
37+
# such that x2[assignments] matches x1
38+
assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
3639
assignments = keras.ops.squeeze(assignments, axis=1)
3740

3841
return assignments
@@ -42,7 +45,7 @@ def sinkhorn_plan(
4245
x1: Tensor,
4346
x2: Tensor,
4447
regularization: float = 1.0,
45-
max_steps: int = 10_000,
48+
max_steps: int = None,
4649
rtol: float = 1e-5,
4750
atol: float = 1e-8,
4851
) -> Tensor:
@@ -59,7 +62,7 @@ def sinkhorn_plan(
5962
Controls the standard deviation of the Gaussian kernel.
6063
6164
:param max_steps: Maximum number of iterations, or None to run until convergence.
62-
Default: 10_000
65+
Default: None
6366
6467
:param rtol: Relative tolerance for convergence.
6568
Default: 1e-5.
@@ -71,17 +74,20 @@ def sinkhorn_plan(
7174
The transport probabilities.
7275
"""
7376
cost = euclidean(x1, x2)
77+
cost_scaled = -cost / regularization
7478

75-
# initialize the transport plan from a gaussian kernel
76-
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
79+
# initialize transport plan from a gaussian kernel
80+
# (more numerically stable version of keras.ops.exp(-cost/regularization))
81+
plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled))
82+
n, m = keras.ops.shape(cost)
7783

7884
def contains_nans(plan):
7985
return keras.ops.any(keras.ops.isnan(plan))
8086

8187
def is_converged(plan):
82-
# for convergence, the plan should be doubly stochastic
83-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0, rtol=rtol, atol=atol))
84-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0, rtol=rtol, atol=atol))
88+
# for convergence, the target marginals must match
89+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol))
90+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol))
8591
return conv0 & conv1
8692

8793
def cond(_, plan):
@@ -90,8 +96,8 @@ def cond(_, plan):
9096

9197
def body(steps, plan):
9298
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
93-
plan = keras.ops.softmax(plan, axis=0)
94-
plan = keras.ops.softmax(plan, axis=1)
99+
plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m)
100+
plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n)
95101

96102
return steps + 1, plan
97103

bayesflow/utils/plot_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def prepare_plot_data(
2323
figsize: tuple = None,
2424
stacked: bool = False,
2525
default_name: str = "v",
26-
) -> Mapping[str, Any]:
26+
) -> dict[str, Any]:
2727
"""
2828
Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name
2929
generation, layout configuration, figure initialization, and collapsing of axes.
@@ -56,6 +56,12 @@ def prepare_plot_data(
5656
Whether the plots are stacked horizontally
5757
default_name : str, optional (default = "v")
5858
The default name to use for estimates if None provided
59+
60+
Returns
61+
-------
62+
plot_data : dict[str, Any]
63+
A dictionary containing all preprocessed data and plotting objects required for visualization,
64+
including estimates, targets, variable names, figure, axes, and layout configuration.
5965
"""
6066

6167
plot_data = dicts_to_arrays(

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ dependencies = [
3636
[project.optional-dependencies]
3737
all = [
3838
# dev
39+
"ipython",
40+
"ipykernel",
3941
"jupyter",
4042
"jupyterlab",
43+
"line-profiler",
4144
"nbconvert",
42-
"ipython",
43-
"ipykernel",
4445
"pre-commit",
4546
"ruff",
4647
"tox",
4748
# docs
48-
4949
"myst-nb ~= 1.2",
5050
"numpydoc ~= 1.8",
5151
"pydata-sphinx-theme ~= 0.16",
@@ -63,6 +63,7 @@ all = [
6363
dev = [
6464
"jupyter",
6565
"jupyterlab",
66+
"line-profiler",
6667
"pre-commit",
6768
"ruff",
6869
"tox",

0 commit comments

Comments
 (0)