3434 sigmoid ,
3535)
3636from pytensor .tensor .blockwise import Blockwise
37+ from pytensor .tensor .einsum import _delta
3738from pytensor .tensor .elemwise import DimShuffle
3839from pytensor .tensor .exceptions import NotScalarConstantError
3940from pytensor .tensor .linalg import cholesky , det , eigh , solve_triangular , trace
7677)
7778from pymc .distributions .transforms import (
7879 CholeskyCorrTransform ,
79- Interval ,
8080 ZeroSumTransform ,
8181 _default_transform ,
8282)
@@ -1157,12 +1157,12 @@ def _lkj_normalizing_constant(eta, n):
11571157 if not isinstance (n , int ):
11581158 raise NotImplementedError ("n must be an integer" )
11591159 if eta == 1 :
1160- result = gammaln (2.0 * pt .arange (1 , int ((n - 1 ) / 2 ) + 1 )).sum ()
1160+ result = gammaln (2.0 * pt .arange (1 , ((n - 1 ) / 2 ) + 1 )).sum ()
11611161 if n % 2 == 1 :
11621162 result += (
11631163 0.25 * (n ** 2 - 1 ) * pt .log (np .pi )
11641164 - 0.25 * (n - 1 ) ** 2 * pt .log (2.0 )
1165- - (n - 1 ) * gammaln (int (( n + 1 ) / 2 ) )
1165+ - (n - 1 ) * gammaln (( n + 1 ) / 2 )
11661166 )
11671167 else :
11681168 result += (
@@ -1504,7 +1504,7 @@ def helper_deterministics(cls, n, packed_chol):
15041504
15051505class LKJCorrRV (SymbolicRandomVariable ):
15061506 name = "lkjcorr"
1507- extended_signature = "[rng],[size],(),()->[rng],(n)"
1507+ extended_signature = "[rng],[size],(),()->[rng],(n,n )"
15081508 _print_name = ("LKJCorrRV" , "\\ operatorname{LKJCorrRV}" )
15091509
15101510 def make_node (self , rng , size , n , eta ):
@@ -1532,23 +1532,13 @@ def rv_op(cls, n: int, eta, *, rng=None, size=None):
15321532 flat_size = pt .prod (size , dtype = "int64" )
15331533
15341534 next_rng , C = cls ._random_corr_matrix (rng = rng , n = n , eta = eta , flat_size = flat_size )
1535-
1536- triu_idx = pt .triu_indices (n , k = 1 )
1537- samples = C [..., triu_idx [0 ], triu_idx [1 ]]
1538-
1539- if rv_size_is_none (size ):
1540- samples = samples [0 ]
1541- else :
1542- dist_shape = (n * (n - 1 )) // 2
1543- samples = pt .reshape (samples , (* size , dist_shape ))
1535+ C = C [0 ] if rv_size_is_none (size ) else C .reshape ((* size , n , n ))
15441536
15451537 return cls (
15461538 inputs = [rng , size , n , eta ],
1547- outputs = [next_rng , samples ],
1539+ outputs = [next_rng , C ],
15481540 )(rng , size , n , eta )
15491541
1550- return samples
1551-
15521542 @classmethod
15531543 def _random_corr_matrix (
15541544 cls , rng : Variable , n : int , eta : TensorVariable , flat_size : TensorVariable
@@ -1565,6 +1555,7 @@ def _random_corr_matrix(
15651555 P = P [..., 0 , 1 ].set (r12 )
15661556 P = P [..., 1 , 1 ].set (pt .sqrt (1.0 - r12 ** 2 ))
15671557 n = get_underlying_scalar_constant_value (n )
1558+
15681559 for mp1 in range (2 , n ):
15691560 beta -= 0.5
15701561 next_rng , y = pt .random .beta (
@@ -1577,17 +1568,10 @@ def _random_corr_matrix(
15771568 P = P [..., 0 :mp1 , mp1 ].set (pt .sqrt (y [..., np .newaxis ]) * z )
15781569 P = P [..., mp1 , mp1 ].set (pt .sqrt (1.0 - y ))
15791570 C = pt .einsum ("...ji,...jk->...ik" , P , P .copy ())
1580- return next_rng , C
1581-
15821571
1583- class MultivariateIntervalTransform (Interval ):
1584- name = "interval"
1585-
1586- def log_jac_det (self , * args ):
1587- return super ().log_jac_det (* args ).sum (- 1 )
1572+ return next_rng , C
15881573
15891574
1590- # Returns list of upper triangular values
15911575class _LKJCorr (BoundedContinuous ):
15921576 rv_type = LKJCorrRV
15931577 rv_op = LKJCorrRV .rv_op
@@ -1598,10 +1582,15 @@ def dist(cls, n, eta, **kwargs):
15981582 eta = pt .as_tensor_variable (eta )
15991583 return super ().dist ([n , eta ], ** kwargs )
16001584
1601- def support_point (rv , * args ):
1602- return pt .zeros_like (rv )
1585+ @staticmethod
1586+ def support_point (rv : TensorVariable , * args ):
1587+ ndim = rv .ndim
16031588
1604- def logp (value , n , eta ):
1589+ # Batched identity matrix
1590+ return _delta (rv .shape , (ndim - 2 , ndim - 1 )).astype (int )
1591+
1592+ @staticmethod
1593+ def logp (value : TensorVariable , n , eta ):
16051594 """
16061595 Calculate logp of LKJ distribution at specified value.
16071596
@@ -1614,31 +1603,20 @@ def logp(value, n, eta):
16141603 -------
16151604 TensorVariable
16161605 """
1617- if value .ndim > 1 :
1618- raise NotImplementedError ("LKJCorr logp is only implemented for vector values (ndim=1)" )
1619-
1620- # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
1621- # n (or else find a different expression)
1606+ # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16221607 try :
16231608 n = int (get_underlying_scalar_constant_value (n ))
16241609 except NotScalarConstantError :
16251610 raise NotImplementedError ("logp only implemented for constant `n`" )
16261611
1627- shape = n * (n - 1 ) // 2
1628- tri_index = np .zeros ((n , n ), dtype = "int32" )
1629- tri_index [np .triu_indices (n , k = 1 )] = np .arange (shape )
1630- tri_index [np .triu_indices (n , k = 1 )[::- 1 ]] = np .arange (shape )
1631-
1632- value = pt .take (value , tri_index )
1633- value = pt .fill_diagonal (value , 1 )
1634-
1635- # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16361612 try :
16371613 eta = float (get_underlying_scalar_constant_value (eta ))
16381614 except NotScalarConstantError :
16391615 raise NotImplementedError ("logp only implemented for constant `eta`" )
1616+
16401617 result = _lkj_normalizing_constant (eta , n )
16411618 result += (eta - 1.0 ) * pt .log (det (value ))
1619+
16421620 return check_parameters (
16431621 result ,
16441622 value >= - 1 ,
@@ -1675,10 +1653,6 @@ class LKJCorr:
16751653 The shape parameter (eta > 0) of the LKJ distribution. eta = 1
16761654 implies a uniform distribution of the correlation matrices;
16771655 larger values put more weight on matrices with few correlations.
1678- return_matrix : bool, default=False
1679- If True, returns the full correlation matrix.
1680- False only returns the values of the upper triangular matrix excluding
1681- diagonal in a single vector of length n(n-1)/2 for memory efficiency
16821656
16831657 Notes
16841658 -----
@@ -1693,7 +1667,7 @@ class LKJCorr:
16931667 # Define the vector of fixed standard deviations
16941668 sds = 3 * np.ones(10)
16951669
1696- corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True )
1670+ corr = pm.LKJCorr("corr", eta=4, n=10)
16971671
16981672 # Define a new MvNormal with the given correlation matrix
16991673 vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10)
@@ -1703,10 +1677,6 @@ class LKJCorr:
17031677 chol = pt.linalg.cholesky(corr)
17041678 vals = sds * pt.dot(chol, vals_raw)
17051679
1706- # The matrix is internally still sampled as a upper triangular vector
1707- # If you want access to it in matrix form in the trace, add
1708- pm.Deterministic("corr_mat", corr)
1709-
17101680
17111681 References
17121682 ----------
@@ -1716,26 +1686,28 @@ class LKJCorr:
17161686 100(9), pp.1989-2001.
17171687 """
17181688
1719- def __new__ (cls , name , n , eta , * , return_matrix = False , ** kwargs ):
1720- c_vec = _LKJCorr (name , eta = eta , n = n , ** kwargs )
1721- if not return_matrix :
1722- return c_vec
1723- else :
1724- return cls .vec_to_corr_mat (c_vec , n )
1725-
1726- @classmethod
1727- def dist (cls , n , eta , * , return_matrix = False , ** kwargs ):
1728- c_vec = _LKJCorr .dist (eta = eta , n = n , ** kwargs )
1729- if not return_matrix :
1730- return c_vec
1731- else :
1732- return cls .vec_to_corr_mat (c_vec , n )
1689+ def __new__ (cls , name , n , eta , ** kwargs ):
1690+ return_matrix = kwargs .pop ("return_matrix" , None )
1691+ if return_matrix is not None :
1692+ warnings .warn (
1693+ "The `return_matrix` argument is deprecated and has no effect. "
1694+ "LKJCorr always returns the correlation matrix." ,
1695+ DeprecationWarning ,
1696+ stacklevel = 2 ,
1697+ )
1698+ return _LKJCorr (name , eta = eta , n = n , ** kwargs )
17331699
17341700 @classmethod
1735- def vec_to_corr_mat (cls , vec , n ):
1736- tri = pt .zeros (pt .concatenate ([vec .shape [:- 1 ], (n , n )]))
1737- tri = pt .subtensor .set_subtensor (tri [(..., * np .triu_indices (n , 1 ))], vec )
1738- return tri + pt .moveaxis (tri , - 2 , - 1 ) + pt .diag (pt .ones (n ))
1701+ def dist (cls , n , eta , ** kwargs ):
1702+ return_matrix = kwargs .pop ("return_matrix" , None )
1703+ if return_matrix is not None :
1704+ warnings .warn (
1705+ "The `return_matrix` argument is deprecated and has no effect. "
1706+ "LKJCorr always returns the correlation matrix." ,
1707+ DeprecationWarning ,
1708+ stacklevel = 2 ,
1709+ )
1710+ return _LKJCorr .dist (eta = eta , n = n , ** kwargs )
17391711
17401712
17411713class MatrixNormalRV (RandomVariable ):
0 commit comments