class MultivariateNormalScore(ParametricDistributionScore):
    r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log(\mathcal N(\theta; \mu, \Sigma))`

    Scores a predicted mean and lower-triangular Cholesky factor :math:`L` of the precision matrix :math:`P`
    with the log-score of the probability of the materialized value. The precision matrix is
    the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
    """

    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor",)
    """
    Marks head for precision matrix Cholesky factor as an exception for adapter transformations.

    This variable contains names of prediction heads that should lead to a warning when the adapter is applied
    in inverse direction to them.

    For more information see :py:class:`ScoringRule`.
    """

    TRANSFORMATION_TYPE: dict[str, str] = {"precision_cholesky_factor": "right_side_scale_inverse"}
    """
    Marks precision Cholesky factor head to handle de-standardization appropriately.

    See :py:class:`bayesflow.networks.Standardization` for more information on supported de-standardization options.

    For the mean head the default ("location_scale") is not overridden.
    """
@@ -42,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
4241 super ().__init__ (links = links , ** kwargs )
4342
4443 self .dim = dim
45- self .links = links or {"cov_chol " : CholeskyFactor ()}
44+ self .links = links or {"precision_cholesky_factor " : CholeskyFactor ()}
4645
4746 self .config = {"dim" : dim }
4847
@@ -52,16 +51,16 @@ def get_config(self):
5251
5352 def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
5453 self .dim = target_shape [- 1 ]
55- return dict (mean = (self .dim ,), cov_chol = (self .dim , self .dim ))
54+ return dict (mean = (self .dim ,), precision_cholesky_factor = (self .dim , self .dim ))
5655
57- def log_prob (self , x : Tensor , mean : Tensor , cov_chol : Tensor ) -> Tensor :
56+ def log_prob (self , x : Tensor , mean : Tensor , precision_cholesky_factor : Tensor ) -> Tensor :
5857 """
5958 Compute the log probability density of a multivariate Gaussian distribution.
6059
6160 This function calculates the log probability density for each sample in `x` under a
62- multivariate Gaussian distribution with the given `mean` and `cov_chol `.
61+ multivariate Gaussian distribution with the given `mean` and `precision_cholesky_factor `.
6362
64- The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
63+ The computation includes the determinant of the precision matrix, its inverse, and the quadratic
6564 form in the exponential term of the Gaussian density function.
6665
6766 Parameters
@@ -71,8 +70,9 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
7170 The shape should be compatible with broadcasting against `mean`.
7271 mean : Tensor
7372 A tensor representing the mean of the multivariate Gaussian distribution.
74- covariance : Tensor
75- A tensor representing the covariance matrix of the multivariate Gaussian distribution.
73+ precision_cholesky_factor : Tensor
74+ A tensor representing the lower-triangular Cholesky factor of the precision matrix
75+ of the multivariate Gaussian distribution.
7676
7777 Returns
7878 -------
@@ -82,29 +82,27 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
8282 """
8383 diff = x - mean
8484
85- # Calculate precision from Cholesky factors of covariance matrix
86- cov_chol_inv = keras .ops .inv (cov_chol )
87- precision = keras .ops .matmul (
88- keras .ops .swapaxes (cov_chol_inv , - 2 , - 1 ),
89- cov_chol_inv ,
90- )
91-
9285 # Compute log determinant, exploiting Cholesky factors
93- log_det_covariance = keras .ops .log (keras .ops .prod (keras .ops .diagonal (cov_chol , axis1 = 1 , axis2 = 2 ), axis = 1 )) * 2
86+ log_det_covariance = - 2 * keras .ops .sum (
87+ keras .ops .log (keras .ops .diagonal (precision_cholesky_factor , axis1 = 1 , axis2 = 2 )), axis = 1
88+ )
9489
95- # Compute the quadratic term in the exponential of the multivariate Gaussian
96- quadratic_term = keras .ops .einsum ("...i,...ij,...j->..." , diff , precision , diff )
90+ # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
91+ # diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
92+ quadratic_term = keras .ops .einsum (
93+ "...i,...ji,...jk,...k->..." , diff , precision_cholesky_factor , precision_cholesky_factor , diff
94+ )
9795
9896 # Compute the log probability density
9997 log_prob = - 0.5 * (self .dim * keras .ops .log (2 * math .pi ) + log_det_covariance + quadratic_term )
10098
10199 return log_prob
102100
103- def sample (self , batch_shape : Shape , mean : Tensor , cov_chol : Tensor ) -> Tensor :
101+ def sample (self , batch_shape : Shape , mean : Tensor , precision_cholesky_factor : Tensor ) -> Tensor :
104102 """
105103 Generate samples from a multivariate Gaussian distribution.
106104
107- Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
105+ Independent standard normal samples are transformed using the Cholesky factor of the precision matrix
108106 to generate correlated samples.
109107
110108 Parameters
@@ -114,32 +112,34 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
114112 mean : Tensor
115113 A tensor representing the mean of the multivariate Gaussian distribution.
116114 Must have shape (batch_size, D), where D is the dimensionality of the distribution.
117- cov_chol : Tensor
118- A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
115+ precision_cholesky_factor : Tensor
116+ A tensor representing the lower-triangular Cholesky factor of the precision matrix
117+ of the multivariate Gaussian distribution.
119118 Must have shape (batch_size, D, D), where D is the dimensionality.
120119
121120 Returns
122121 -------
123122 Tensor
124123 A tensor of shape (batch_size, num_samples, D) containing the generated samples.
125124 """
125+ covariance_cholesky_factor = keras .ops .inv (precision_cholesky_factor )
126126 if len (batch_shape ) == 1 :
127127 batch_shape = (1 ,) + tuple (batch_shape )
128128 batch_size , num_samples = batch_shape
129129 dim = keras .ops .shape (mean )[- 1 ]
130130 if keras .ops .shape (mean ) != (batch_size , dim ):
131131 raise ValueError (f"mean must have shape (batch_size, { dim } ), but got { keras .ops .shape (mean )} " )
132132
133- if keras .ops .shape (cov_chol ) != (batch_size , dim , dim ):
133+ if keras .ops .shape (precision_cholesky_factor ) != (batch_size , dim , dim ):
134134 raise ValueError (
135135 f"covariance Cholesky factor must have shape (batch_size, { dim } , { dim } ),"
136- f"but got { keras .ops .shape (cov_chol )} "
136+ f"but got { keras .ops .shape (precision_cholesky_factor )} "
137137 )
138138
139139 # Use Cholesky decomposition to generate samples
140140 normal_samples = keras .random .normal ((* batch_shape , dim ))
141141
142- scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , cov_chol , normal_samples )
142+ scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , covariance_cholesky_factor , normal_samples )
143143 samples = mean [:, None , :] + scaled_normal
144144
145145 return samples
0 commit comments