     sigmoid,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.einsum import _delta
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -1520,53 +1519,50 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, C],
-        )(rng, size, n, eta)
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
 
+        size_is_none = rv_size_is_none(size)
+        size = () if size_is_none else size
+        ein_sig_z = "i, i->" if size_is_none else "...ij, ...ij->...i"
+
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
         return next_rng, C
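
The hunk above removes the flatten-and-reshape round trip: rather than drawing `flat_size` matrices and reshaping to `(*size, n, n)` afterwards, the beta and normal draws now take the batched `size` directly, and the normalization einsum switches its signature depending on whether `size` is empty. For orientation, here is a minimal NumPy sketch of the onion-method construction the loop implements for a single matrix; the helper name `onion_corr` and the use of NumPy are illustrative only, not part of this change:

import numpy as np

def onion_corr(n: int, eta: float, rng: np.random.Generator) -> np.ndarray:
    """Draw one n x n correlation matrix (n >= 2) with LKJ(eta) density."""
    beta = eta - 1.0 + n / 2.0
    r12 = 2.0 * rng.beta(beta, beta) - 1.0
    P = np.eye(n)
    P[0, 1] = r12
    P[1, 1] = np.sqrt(1.0 - r12**2)
    for mp1 in range(2, n):
        beta -= 0.5
        y = rng.beta(mp1 / 2.0, beta)  # squared length of the new off-diagonal column
        z = rng.normal(size=mp1)
        z /= np.sqrt(z @ z)            # uniform direction on the unit sphere
        P[0:mp1, mp1] = np.sqrt(y) * z
        P[mp1, mp1] = np.sqrt(1.0 - y)
    return P.T @ P                     # columns of P are unit vectors, so diag(C) == 1

C = onion_corr(4, eta=2.0, rng=np.random.default_rng(0))
assert np.allclose(np.diag(C), 1.0)
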
@@ -1584,10 +1580,7 @@ def dist(cls, n, eta, **kwargs):
 
     @staticmethod
     def support_point(rv: TensorVariable, *args):
-        ndim = rv.ndim
-
-        # Batched identity matrix
-        return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
     @staticmethod
     def logp(value: TensorVariable, n, eta):
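
The new `support_point` replaces the `_delta`-based construction with a broadcast of `pt.eye` over the batch shape, which is also why the `_delta` import disappears at the top of the file. A quick NumPy analogue (illustrative, not part of the change) of what it returns:

import numpy as np

shape = (3, 5, 4, 4)  # a (3, 5) batch of 4 x 4 correlation matrices
sp = np.broadcast_to(np.eye(shape[-1]), shape)
assert sp.shape == shape
assert (sp == np.eye(4)).all()  # every batch member is the identity matrix

One visible difference in the diff itself: the removed version cast its result with `.astype(int)`, while `pt.eye` defaults to a float dtype.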