
Commit d68c9dd

vpratz and han-ol authored
Breaking: parameterize MVNormalScore by inverse Cholesky factor to improve stability (#545)
* breaking: parameterize MVNormalScore by inverse Cholesky factor

  The log_prob can be computed entirely from the inverse Cholesky factor L^{-1}. Using it also stabilizes the initial loss and speeds up computation. This commit additionally contains two optimizations: moving the computation of the precision matrix into the einsum, and using the sum of the logs instead of the log of a product. As the parameterization changes, this is a breaking change.

* Add right_side_scale_inverse and test [no ci]

  The transformation necessary to undo standardization for a Cholesky factor of the precision matrix is x_ij = x_ij' / sigma_j, which is now implemented by a right_side_scale_inverse transformation_type.

* Stop skipping MVN tests

* Remove stray keyword argument in fill_triangular_matrix

* Rename cov_chol_inv to precision_chol and update docstrings [no ci]

* Rename precision_chol to precision_cholesky_factor to improve clarity

* Rename cov_chol to covariance_cholesky_factor

* Remove check_approximator_multivariate_normal_score function [no ci]

---------

Co-authored-by: han-ol <g@hans.olischlaeger.com>
1 parent: fb3191b · commit: d68c9dd

File tree

8 files changed: +53 −56 lines changed

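For reference, the identity the new parameterization exploits: with a lower-triangular factor :math:`L` of the precision matrix, :math:`L^T L = P = \Sigma^{-1}`, the Gaussian log-density needs neither the covariance matrix nor an explicit matrix inverse,

    \log \mathcal{N}(\theta; \mu, \Sigma) = -\frac{1}{2} \left( d \log 2\pi + \log\lvert\Sigma\rvert + \lVert L(\theta - \mu) \rVert^2 \right),
    \qquad \log\lvert\Sigma\rvert = -2 \sum_{i=1}^{d} \log L_{ii},

since the quadratic form satisfies :math:`(\theta-\mu)^T P (\theta-\mu) = \lVert L(\theta-\mu) \rVert^2` and the determinant of a triangular matrix is the product of its diagonal entries.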

bayesflow/networks/standardization/standardization.py

Lines changed: 3 additions & 0 deletions

@@ -117,6 +117,9 @@ def call(
             case "left_side_scale":
                 # x_ij = sigma_i * x_ij'
                 out = val * keras.ops.moveaxis(std, -1, -2)
+            case "right_side_scale_inverse":
+                # x_ij = x_ij' / sigma_j
+                out = val / std
             case _:
                 out = val
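A minimal NumPy sketch (an illustration, not part of the commit; variable names are made up) of why the new branch divides by sigma_j: if x = mu + sigma * x' elementwise, a lower-triangular factor K' of the standardized precision matrix de-standardizes to K = K' D^{-1} with D = diag(sigma), i.e. K_ij = K'_ij / sigma_j.

import numpy as np

rng = np.random.default_rng(1)
dim = 3
sigma = np.array([0.5, 2.0, 4.0])  # per-dimension standard deviations

# Covariance of the standardized variable x' and a factor K' with K'^T K' = Sigma'^{-1}
A = rng.normal(size=(dim, dim))
cov_std = A @ A.T + dim * np.eye(dim)
K_std = np.linalg.inv(np.linalg.cholesky(cov_std))

# De-standardization x = sigma * x' scales the covariance to Sigma = D Sigma' D
cov = np.diag(sigma) @ cov_std @ np.diag(sigma)

# right_side_scale_inverse: divide column j by sigma_j, i.e. K = K' D^{-1}
K = K_std / sigma

# K stays lower triangular, and K^T K recovers the de-standardized precision matrix
np.testing.assert_allclose(K.T @ K, np.linalg.inv(cov), rtol=1e-8, atol=1e-10)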

bayesflow/scores/multivariate_normal_score.py

Lines changed: 33 additions & 33 deletions

@@ -13,27 +13,26 @@
 class MultivariateNormalScore(ParametricDistributionScore):
     r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`
 
-    Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score of the probability
-    of the materialized value.
+    Scores a predicted mean and lower-triangular Cholesky factor :math:`L` of the precision matrix :math:`P`
+    with the log-score of the probability of the materialized value. The precision matrix is
+    the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
     """
 
-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor",)
     """
-    Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.
+    Marks head for precision matrix Cholesky factor as an exception for adapter transformations.
 
     This variable contains names of prediction heads that should lead to a warning when the adapter is applied
     in inverse direction to them.
 
     For more information see :py:class:`ScoringRule`.
     """
 
-    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
+    TRANSFORMATION_TYPE: dict[str, str] = {"precision_cholesky_factor": "right_side_scale_inverse"}
     """
-    Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.
+    Marks precision Cholesky factor head to handle de-standardization appropriately.
 
-    The appropriate inverse of the standardization operation is
-
-        x_ij = sigma_i * x_ij'.
+    See :py:class:`bayesflow.networks.Standardization` for more information on supported de-standardization options.
 
     For the mean head the default ("location_scale") is not overridden.
     """

@@ -42,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)
 
         self.dim = dim
-        self.links = links or {"cov_chol": CholeskyFactor()}
+        self.links = links or {"precision_cholesky_factor": CholeskyFactor()}
 
         self.config = {"dim": dim}

@@ -52,16 +51,16 @@ def get_config(self):
 
     def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
         self.dim = target_shape[-1]
-        return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))
+        return dict(mean=(self.dim,), precision_cholesky_factor=(self.dim, self.dim))
 
-    def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def log_prob(self, x: Tensor, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
         """
         Compute the log probability density of a multivariate Gaussian distribution.
 
         This function calculates the log probability density for each sample in `x` under a
-        multivariate Gaussian distribution with the given `mean` and `cov_chol`.
+        multivariate Gaussian distribution with the given `mean` and `precision_cholesky_factor`.
 
-        The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
+        The computation includes the determinant of the precision matrix, its inverse, and the quadratic
         form in the exponential term of the Gaussian density function.

@@ -71,8 +70,9 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
             The shape should be compatible with broadcasting against `mean`.
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution.
-        covariance : Tensor
-            A tensor representing the covariance matrix of the multivariate Gaussian distribution.
+        precision_cholesky_factor : Tensor
+            A tensor representing the lower-triangular Cholesky factor of the precision matrix
+            of the multivariate Gaussian distribution.
 
         Returns
         -------

@@ -82,29 +82,27 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         diff = x - mean
 
-        # Calculate precision from Cholesky factors of covariance matrix
-        cov_chol_inv = keras.ops.inv(cov_chol)
-        precision = keras.ops.matmul(
-            keras.ops.swapaxes(cov_chol_inv, -2, -1),
-            cov_chol_inv,
-        )
-
         # Compute log determinant, exploiting Cholesky factors
-        log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2
+        log_det_covariance = -2 * keras.ops.sum(
+            keras.ops.log(keras.ops.diagonal(precision_cholesky_factor, axis1=1, axis2=2)), axis=1
+        )
 
-        # Compute the quadratic term in the exponential of the multivariate Gaussian
-        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
+        # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
+        # diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
+        quadratic_term = keras.ops.einsum(
+            "...i,...ji,...jk,...k->...", diff, precision_cholesky_factor, precision_cholesky_factor, diff
+        )
 
         # Compute the log probability density
         log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)
 
         return log_prob
 
-    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
         """
         Generate samples from a multivariate Gaussian distribution.
 
-        Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
+        Independent standard normal samples are transformed using the Cholesky factor of the precision matrix
         to generate correlated samples.
 
         Parameters

@@ -114,32 +112,34 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution.
             Must have shape (batch_size, D), where D is the dimensionality of the distribution.
-        cov_chol : Tensor
-            A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
+        precision_cholesky_factor : Tensor
+            A tensor representing the lower-triangular Cholesky factor of the precision matrix
+            of the multivariate Gaussian distribution.
             Must have shape (batch_size, D, D), where D is the dimensionality.
 
         Returns
         -------
         Tensor
             A tensor of shape (batch_size, num_samples, D) containing the generated samples.
         """
+        covariance_cholesky_factor = keras.ops.inv(precision_cholesky_factor)
         if len(batch_shape) == 1:
             batch_shape = (1,) + tuple(batch_shape)
         batch_size, num_samples = batch_shape
         dim = keras.ops.shape(mean)[-1]
         if keras.ops.shape(mean) != (batch_size, dim):
             raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")
 
-        if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
+        if keras.ops.shape(precision_cholesky_factor) != (batch_size, dim, dim):
             raise ValueError(
                 f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
-                f"but got {keras.ops.shape(cov_chol)}"
+                f"but got {keras.ops.shape(precision_cholesky_factor)}"
             )
 
         # Use Cholesky decomposition to generate samples
         normal_samples = keras.random.normal((*batch_shape, dim))
 
-        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cov_chol, normal_samples)
+        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", covariance_cholesky_factor, normal_samples)
         samples = mean[:, None, :] + scaled_normal
 
         return samples
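A hedged NumPy sketch (not from the repository) of the three identities the rewritten methods rely on, assuming a lower-triangular L with L^T L = Sigma^{-1}; here L is constructed from a known covariance purely for checking.

import numpy as np

rng = np.random.default_rng(2)
dim = 3
A = rng.normal(size=(dim, dim))
cov = A @ A.T + dim * np.eye(dim)  # a well-conditioned SPD covariance

K = np.linalg.cholesky(cov)  # Sigma = K K^T, K lower triangular
L = np.linalg.inv(K)         # lower triangular, L^T L = Sigma^{-1}

# log_prob: the einsum "...i,...ji,...jk,...k->..." computes diff^T L^T L diff
diff = rng.normal(size=dim)
quad = np.einsum("i,ji,jk,k->", diff, L, L, diff)
assert np.isclose(quad, diff @ np.linalg.solve(cov, diff))

# log_prob: log|Sigma| = -2 * sum(log(diag(L))), the sum of logs replacing the log of a product
assert np.isclose(-2 * np.log(np.diag(L)).sum(), np.linalg.slogdet(cov)[1])

# sample: inverting L yields a covariance factor C with C C^T = Sigma,
# so mean + C @ z generates samples with the intended covariance
C = np.linalg.inv(L)
np.testing.assert_allclose(C @ C.T, cov, rtol=1e-8)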

bayesflow/utils/tensor_utils.py

Lines changed: 1 addition & 1 deletion

@@ -311,7 +311,7 @@ def stack(*items):
     return keras.tree.map_structure(stack, *structures)
 
 
-def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool = False):
+def fill_triangular_matrix(x: Tensor, upper: bool = False):
     """
     Reshapes a batch of matrix elements into a triangular matrix (either upper or lower).
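For intuition, a toy version of what such a helper does. The row-major packing order below is an assumption for illustration only; the actual layout used by fill_triangular_matrix may differ.

import numpy as np

def fill_lower_triangular(flat: np.ndarray, dim: int) -> np.ndarray:
    # Pack dim * (dim + 1) // 2 values into a lower-triangular (dim, dim) matrix.
    out = np.zeros((dim, dim), dtype=flat.dtype)
    out[np.tril_indices(dim)] = flat  # row-major fill (assumed order)
    return out

# A head predicting a dim x dim Cholesky factor only needs dim*(dim+1)//2 outputs, e.g. 6 for dim=3:
print(fill_lower_triangular(np.arange(1.0, 7.0), dim=3))
# [[1. 0. 0.]
#  [2. 3. 0.]
#  [4. 5. 6.]]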

tests/test_approximators/test_fit.py

Lines changed: 0 additions & 4 deletions

@@ -3,7 +3,6 @@
 import pytest
 import io
 from contextlib import redirect_stdout
-from tests.utils import check_approximator_multivariate_normal_score
 
 
 @pytest.mark.skip(reason="not implemented")

@@ -20,9 +19,6 @@ def test_fit(amortizer, dataset):
 
 
 def test_loss_progress(approximator, train_dataset, validation_dataset):
-    # as long as MultivariateNormalScore is unstable, skip fit progress test
-    check_approximator_multivariate_normal_score(approximator)
-
     approximator.compile(optimizer="AdamW")
     num_epochs = 3

tests/test_approximators/test_log_prob.py

Lines changed: 1 addition & 3 deletions

@@ -1,12 +1,10 @@
 import keras
 import numpy as np
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
 
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))

tests/test_approximators/test_sample.py

Lines changed: 1 addition & 3 deletions

@@ -1,11 +1,9 @@
 import keras
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_sample(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
 
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))

tests/test_networks/test_standardization.py

Lines changed: 14 additions & 2 deletions

@@ -1,3 +1,4 @@
+import pytest
 import numpy as np
 import keras
 

@@ -156,9 +157,11 @@ def test_transformation_type_both_sides_scale():
     np.testing.assert_allclose(cov_input, cov_standardized_and_recovered, atol=1e-4)
 
 
-def test_transformation_type_left_side_scale():
+@pytest.mark.parametrize("transformation_type", ["left_side_scale", "right_side_scale_inverse"])
+def test_transformation_type_one_side_scale(transformation_type):
     # Fix a known covariance and mean in original (not standardized space)
     covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
+
     mean = np.array([1, 10], dtype="float32")
 
     # Generate samples

@@ -177,16 +180,25 @@ def test_transformation_type_left_side_scale():
     cov_standardized = np.cov(keras.ops.convert_to_numpy(standardized), rowvar=False)
     cov_standardized = keras.ops.convert_to_tensor(cov_standardized)
     chol_standardized = keras.ops.cholesky(cov_standardized)  # (dim, dim)
+
+    # We test the right_side_scale_inverse transformation by backtransforming a precision chol factor
+    # instead of a covariance chol factor.
+    if "inverse" in transformation_type:
+        chol_standardized = keras.ops.inv(chol_standardized)
+
     # Inverse standardization of covariance matrix in standardized space
     chol_standardized_and_recovered = layer(
-        chol_standardized, stage="inference", forward=False, transformation_type="left_side_scale"
+        chol_standardized, stage="inference", forward=False, transformation_type=transformation_type
     )
 
     random_input = keras.ops.convert_to_numpy(random_input)
     chol_standardized_and_recovered = keras.ops.convert_to_numpy(chol_standardized_and_recovered)
     cov_input = np.cov(random_input, rowvar=False)
     chol_input = np.linalg.cholesky(cov_input)
 
+    if "inverse" in transformation_type:
+        chol_input = np.linalg.inv(chol_input)
+
     np.testing.assert_allclose(chol_input, chol_standardized_and_recovered, atol=1e-4)
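The two parameterized branches of this test check complementary identities for standardization x' = (x - mu) / sigma with D = diag(sigma): a covariance Cholesky factor de-standardizes as L = D L' (left_side_scale, L_ij = sigma_i * L'_ij), while a precision Cholesky factor de-standardizes as K = K' D^{-1} (right_side_scale_inverse, K_ij = K'_ij / sigma_j). That is why the "inverse" branch inverts both the standardized factor and the reference factor before comparing.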

tests/utils/check_combinations.py

Lines changed: 0 additions & 10 deletions

@@ -19,13 +19,3 @@ def check_combination_simulator_adapter(simulator, adapter):
     # to be used as sample weight, no error is raised currently.
     # Don't use this fixture combination for further tests.
     pytest.skip(reason="Do not use this fixture combination for further tests")  # TODO: better reason
-
-
-def check_approximator_multivariate_normal_score(approximator):
-    from bayesflow.approximators import PointApproximator
-    from bayesflow.scores import MultivariateNormalScore
-
-    if isinstance(approximator, PointApproximator):
-        for score in approximator.inference_network.scores.values():
-            if isinstance(score, MultivariateNormalScore):
-                pytest.skip(reason="MultivariateNormalScore is unstable")
