diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f76a98546..9435b40fa 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2124,8 +2124,8 @@ def logp(value, rng, size, mu, sigma, *covs): sqrt_quad = sqrt_quad / pt.sqrt(eigs[:, None]) logdet = pt.sum(pt.log(eigs)) - # Square each sample - quad = pt.batched_dot(sqrt_quad.T, sqrt_quad.T) + # Square each sample - compute squared norm for each sample + quad = pt.sum(sqrt_quad.T**2, axis=-1) if onedim: quad = quad[0]