
Commit 76500df

Stabilize MultivariateNormalScore by constraining initialization in PositiveDefinite link (#469)
* Refactor fill_triangular_matrix
* Stable positive definite link, fix for #468
* Minor changes to docstring
* Remove self.built = True that prevented registering layer norm in build()
* np -> keras.ops
1 parent a4d58c9 commit 76500df

3 files changed, +82 -54 lines

bayesflow/links/positive_definite.py

Lines changed: 17 additions & 11 deletions
@@ -1,7 +1,7 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs, fill_triangular_matrix
+from bayesflow.utils import layer_kwargs, fill_triangular_matrix, positive_diag
 from bayesflow.utils.serialization import serializable
 
 
@@ -11,16 +11,21 @@ class PositiveDefinite(keras.Layer):
 
     def __init__(self, **kwargs):
         super().__init__(**layer_kwargs(kwargs))
-        self.built = True
+
+        self.layer_norm = keras.layers.LayerNormalization()
 
     def call(self, inputs: Tensor) -> Tensor:
-        # Build cholesky factor from inputs
-        L = fill_triangular_matrix(inputs, positive_diag=True)
+        # normalize the activation at initialization time: mean = 0.0, std = 0.1
+        inputs = self.layer_norm(inputs) / 10
+
+        # form a cholesky factor
+        L = fill_triangular_matrix(inputs)
+        L = positive_diag(L)
 
-        # calculate positive definite matrix from cholesky factors
+        # calculate positive definite matrix from cholesky factors:
         psd = keras.ops.matmul(
             L,
-            keras.ops.moveaxis(L, -2, -1),  # L transposed
+            keras.ops.swapaxes(L, -2, -1),  # L transposed
         )
         return psd
 
@@ -31,13 +36,14 @@ def compute_output_shape(self, input_shape):
 
     def compute_input_shape(self, output_shape):
         """
-        Returns the shape of parameterization of a cholesky factor triangular matrix.
+        Returns the shape of parameterization of a Cholesky factor triangular matrix.
 
-        There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
+        There are :math:`m` nonzero elements of a lower triangular :math:`n \\times n` matrix with
+        :math:`m = n (n + 1) / 2`, so for output shape (..., n, n) the returned shape is (..., m).
 
-        Example
-        -------
-        >>> PositiveDefinite().compute_output_shape((None, 3, 3))
+        Examples
+        --------
+        >>> PositiveDefinite().compute_input_shape((None, 3, 3))
         6
         """
         n = output_shape[-1]
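For context, a minimal usage sketch of the updated link (not part of the commit): it assumes PositiveDefinite is importable from bayesflow.links and that the Keras 3 ops API is available. The link now normalizes the incoming activation, builds an unconstrained Cholesky factor with fill_triangular_matrix, makes its diagonal positive with positive_diag, and returns L @ L^T.

import keras
import numpy as np

from bayesflow.links import PositiveDefinite  # assumed import path

link = PositiveDefinite()

# For n = 3, the flat parameterization has m = n * (n + 1) / 2 = 6 entries per matrix.
x = keras.random.normal((8, 6))

# call(): LayerNormalization + division by 10 keeps early activations small,
# fill_triangular_matrix builds the Cholesky factor L, positive_diag makes
# diag(L) > 0, and the returned matrix is L @ L^T.
psd = link(x)
print(psd.shape)  # (8, 3, 3)

# Each matrix should be symmetric positive definite, i.e. eigenvalues strictly > 0.
eigenvalues = np.linalg.eigvalsh(keras.ops.convert_to_numpy(psd))
print(np.all(eigenvalues > 0))

Dividing the layer-normalized activation by 10 gives inputs with roughly zero mean and 0.1 standard deviation at initialization, which should keep the initial Cholesky factors small and the predicted covariance matrices well conditioned.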

bayesflow/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -88,14 +88,15 @@
     expand_right_as,
     expand_right_to,
     expand_tile,
+    fill_triangular_matrix,
     pad,
+    positive_diag,
     searchsorted,
     size_of,
     stack_valid,
     tile_axis,
     tree_concatenate,
     tree_stack,
-    fill_triangular_matrix,
     weighted_mean,
 )

bayesflow/utils/tensor_utils.py

Lines changed: 63 additions & 42 deletions
@@ -310,8 +310,6 @@ def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool =
         Batch of flattened nonzero matrix elements for triangular matrix.
     upper : bool
         Return upper triangular matrix if True, else lower triangular matrix. Default is False.
-    positive_diag : bool
-        Whether to apply a softplus operation to diagonal elements. Default is False.
 
     Returns
     -------
@@ -327,47 +325,70 @@ def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool =
     batch_shape = x.shape[:-1]
     m = x.shape[-1]
 
-    if m == 1:
-        y = keras.ops.reshape(x, (-1, 1, 1))
-        if positive_diag:
-            y = keras.activations.softplus(y)
-        return y
-
-    # Calculate matrix shape
-    n = (0.25 + 2 * m) ** 0.5 - 0.5
-    if not np.isclose(np.floor(n), n):
-        raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")
-    else:
-        n = int(n)
-
-    # Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
-    x_tail = keras.ops.take(x, indices=list(range((m - (n**2 - m)), x.shape[-1])), axis=-1)
-    if not upper:
-        y = keras.ops.concatenate([x_tail, keras.ops.flip(x, axis=-1)], axis=len(batch_shape))
-        y = keras.ops.reshape(y, (-1, n, n))
-        y = keras.ops.tril(y)
-
-        if positive_diag:
-            y_offdiag = keras.ops.tril(y, k=-1)
-            # carve out diagonal, by setting upper and lower offdiagonals to zero
-            y_diag = keras.ops.tril(
-                keras.ops.triu(keras.activations.softplus(y)),  # apply softplus to enforce positivity
+    if m > 1:  # Matrix is larger than 1x1
+        # Calculate matrix shape
+        n = (0.25 + 2 * m) ** 0.5 - 0.5
+        if not np.isclose(np.floor(n), n):
+            raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")
+        else:
+            n = int(n)
+
+        # Trick: Create triangular matrix by concatenating with a flipped version of itself, then reshape.
+        if not upper:
+            x_list = [x, keras.ops.flip(x[..., n:], axis=-1)]
+
+            y = keras.ops.concatenate(x_list, axis=len(batch_shape))
+            y = keras.ops.reshape(y, (-1, n, n))
+            y = keras.ops.tril(y)
+
+        else:
+            x_list = [x[..., n:], keras.ops.flip(x, axis=-1)]
+
+            y = keras.ops.concatenate(x_list, axis=len(batch_shape))
+            y = keras.ops.reshape(y, (-1, n, n))
+            y = keras.ops.triu(
+                y,
             )
-            y = y_diag + y_offdiag
 
-    else:
-        y = keras.ops.concatenate([x, keras.ops.flip(x_tail, axis=-1)], axis=len(batch_shape))
-        y = keras.ops.reshape(y, (-1, n, n))
-        y = keras.ops.triu(
-            y,
-        )
-
-        if positive_diag:
-            y_offdiag = keras.ops.triu(y, k=1)
-            # carve out diagonal, by setting upper and lower offdiagonals to zero
-            y_diag = keras.ops.tril(
-                keras.ops.triu(keras.activations.softplus(y)),  # apply softplus to enforce positivity
-            )
-            y = y_diag + y_offdiag
+    else:  # Matrix is 1x1
+        y = keras.ops.reshape(x, (-1, 1, 1))
 
     return y
+
+
+def positive_diag(x: Tensor, method="default") -> Tensor:
+    """
+    Ensures that matrix elements on diagonal are positive.
+
+    Parameters
+    ----------
+    x : Tensor of shape (batch_size, n, n)
+        Batch of matrices.
+    method : str, optional
+        Method by which to ensure positivity of diagonal entries. Choose from
+        - "shifted_softplus": softplus(x + 0.5413)
+        - "exp": exp(x)
+        Both methods map a matrix filled with zeros to the unit matrix.
+        Default is "shifted_softplus".
+
+    Returns
+    -------
+    Tensor of shape (batch_size, n, n)
+    """
+    # ensure positivity
+    match method:
+        case "default" | "shifted_softplus":
+            x_positive = keras.activations.softplus(x + 0.5413)
+        case "exp":
+            x_positive = keras.ops.exp(x)
+
+    # zero all offdiagonals
+    x_diag_positive = keras.ops.tril(keras.ops.triu(x_positive))
+
+    # zero diagonal entries
+    x_offdiag = keras.ops.triu(x, k=1) + keras.ops.tril(x, k=-1)
+
+    # sum to get full matrices with softplus applied only to diagonal entries
+    x = x_diag_positive + x_offdiag
+
+    return x
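A quick illustration of how the two helpers compose after this change (a sketch, not from the commit; it assumes the bayesflow.utils exports shown in the __init__.py diff above): a zero input vector yields a zero lower-triangular factor, and the shifted softplus maps its zero diagonal to approximately 1, so the result is roughly the identity matrix.

import keras

from bayesflow.utils import fill_triangular_matrix, positive_diag

# n = 3 requires m = n * (n + 1) / 2 = 6 flat entries per matrix.
flat = keras.ops.zeros((2, 6))

L = fill_triangular_matrix(flat)  # (2, 3, 3) lower-triangular, no positivity constraint
L = positive_diag(L)              # softplus(0 + 0.5413) is approximately 1 on the diagonal

print(keras.ops.convert_to_numpy(L[0]))
# approximately:
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]

Per the docstring, both methods map a zero matrix to the unit matrix, which is what the commit title refers to as constraining initialization: together with the scaled layer normalization in the link, the initial Cholesky factors start near the identity.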
