     sigmoid,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.einsum import _delta
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -1213,12 +1212,8 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]
 
-        # We flatten the size to make operations easier, and then rebuild it
-        flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D_matrix = D.reshape((flat_size, n))
-        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        C *= D[..., :, None] * D[..., None, :]
 
         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
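The scaling step now broadcasts the batched standard deviations `D` directly against the batched correlation matrices, instead of flattening everything to `(flat_size, n)` and reshaping back. A minimal NumPy sketch of the identity this relies on (shapes and names are illustrative, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
D = rng.uniform(0.5, 2.0, size=(4, 3))            # batched std devs, shape (*size, n)
C = np.broadcast_to(np.eye(3), (4, 3, 3)).copy()  # batched correlation matrices

# Broadcasting D over the last two axes is the batched diag(D) @ C @ diag(D)
cov = C * D[..., :, None] * D[..., None, :]

b = 2
np.testing.assert_allclose(cov[b], np.diag(D[b]) @ C[b] @ np.diag(D[b]))
```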
@@ -1520,53 +1515,52 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
+        # HACK: normalize_size_param doesn't handle size=() properly
+        if not size:
+            size = None
+
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, C],
-        )(rng, size, n, eta)
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        size = () if rv_size_is_none(size) else size
 
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i"
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+
         C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
         return next_rng, C
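Threading the user-facing `size` through every `pt.random` call (with `size = ()` for the unbatched case) lets the onion construction run on arbitrary batch shapes without the flatten/reshape round-trip. For reference, a single-matrix NumPy transcription of the onion algorithm the loop above implements (a hedged sketch, not the PyTensor graph itself; the function name is illustrative):

```python
import numpy as np

def onion_corr(n: int, eta: float, rng: np.random.Generator) -> np.ndarray:
    """Sample one n x n correlation matrix via the onion method."""
    beta = eta - 1.0 + n / 2.0
    P = np.eye(n)
    r12 = 2.0 * rng.beta(beta, beta) - 1.0
    P[0, 1] = r12
    P[1, 1] = np.sqrt(1.0 - r12**2)
    for mp1 in range(2, n):
        beta -= 0.5
        y = rng.beta(mp1 / 2.0, beta)
        z = rng.normal(size=mp1)
        z /= np.sqrt(z @ z)            # project onto the unit sphere
        P[0:mp1, mp1] = np.sqrt(y) * z  # grow the matrix one column at a time
        P[mp1, mp1] = np.sqrt(1.0 - y)
    return P.T @ P                      # C = P^T P, matching einsum "...ji,...jk->...ik"

C = onion_corr(4, eta=2.0, rng=np.random.default_rng(0))
assert np.allclose(np.diag(C), 1.0)    # unit diagonal, i.e. a correlation matrix
```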
@@ -1584,10 +1578,7 @@ def dist(cls, n, eta, **kwargs):
 
     @staticmethod
     def support_point(rv: TensorVariable, *args):
-        ndim = rv.ndim
-
-        # Batched identity matrix
-        return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
     @staticmethod
     def logp(value: TensorVariable, n, eta):
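The new `support_point` replaces the private `_delta` helper (whose import is removed above) with a plain broadcast of the identity matrix, which covers any batch shape. A minimal NumPy sketch of the equivalent operation (the shape is illustrative):

```python
import numpy as np

shape = (5, 2, 3, 3)  # (*batch, n, n)
batched_eye = np.broadcast_to(np.eye(shape[-1]), shape)

assert batched_eye.shape == shape
assert np.allclose(batched_eye[4, 1], np.eye(3))
```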