@@ -180,29 +180,32 @@ def _generic_rv_add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariab
180180# Constant - Constant Arithmetic
181181########################################################################################
182182
183- _add_fns [(_Constant , _Constant )] = _Constant ._binary_operator_factory (operator .add )
184- _sub_fns [(_Constant , _Constant )] = _Constant ._binary_operator_factory (operator .sub )
185- _mul_fns [(_Constant , _Constant )] = _Constant ._binary_operator_factory (operator .mul )
186- _matmul_fns [(_Constant , _Constant )] = _Constant ._binary_operator_factory (
183+ _constant_constant_operator_factory = (
184+ _Constant ._binary_operator_factory # pylint: disable=protected-access
185+ )
186+
187+ _add_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (operator .add )
188+ _sub_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (operator .sub )
189+ _mul_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (operator .mul )
190+ _matmul_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (
187191 operator .matmul
188192)
189- _truediv_fns [(_Constant , _Constant )] = _Constant . _binary_operator_factory (
193+ _truediv_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (
190194 operator .truediv
191195)
192- _floordiv_fns [(_Constant , _Constant )] = _Constant . _binary_operator_factory (
196+ _floordiv_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (
193197 operator .floordiv
194198)
195- _mod_fns [(_Constant , _Constant )] = _Constant . _binary_operator_factory (operator .mod )
196- _divmod_fns [(_Constant , _Constant )] = _Constant . _binary_operator_factory (divmod )
197- _pow_fns [(_Constant , _Constant )] = _Constant . _binary_operator_factory (operator .pow )
199+ _mod_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (operator .mod )
200+ _divmod_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (divmod )
201+ _pow_fns [(_Constant , _Constant )] = _constant_constant_operator_factory (operator .pow )
198202
199203########################################################################################
200204# Normal - Normal Arithmetic
201205########################################################################################
202206
203- _add_fns [(_Normal , _Normal )] = _Normal ._add_normal
204- _sub_fns [(_Normal , _Normal )] = _Normal ._sub_normal
205-
207+ _add_fns [(_Normal , _Normal )] = _Normal ._add_normal # pylint: disable=protected-access
208+ _sub_fns [(_Normal , _Normal )] = _Normal ._sub_normal # pylint: disable=protected-access
206209
207210########################################################################################
208211# Normal - Constant Arithmetic
@@ -254,16 +257,16 @@ def _mul_normal_constant(
254257 return _Constant (
255258 support = backend .zeros_like (norm_rv .mean ),
256259 )
260+
261+ if norm_rv .cov_cholesky_is_precomputed :
262+ cov_cholesky = constant_rv .support * norm_rv .cov_cholesky
257263 else :
258- if norm_rv .cov_cholesky_is_precomputed :
259- cov_cholesky = constant_rv .support * norm_rv .cov_cholesky
260- else :
261- cov_cholesky = None
262- return _Normal (
263- mean = constant_rv .support * norm_rv .mean ,
264- cov = (constant_rv .support ** 2 ) * norm_rv .cov ,
265- cov_cholesky = cov_cholesky ,
266- )
264+ cov_cholesky = None
265+ return _Normal (
266+ mean = constant_rv .support * norm_rv .mean ,
267+ cov = (constant_rv .support ** 2 ) * norm_rv .cov ,
268+ cov_cholesky = cov_cholesky ,
269+ )
267270
268271 return NotImplemented
269272
@@ -275,7 +278,8 @@ def _mul_normal_constant(
275278def _matmul_normal_constant (norm_rv : _Normal , constant_rv : _Constant ) -> _Normal :
276279 """Normal random variable multiplied with a vector or matrix.
277280
278- Computes the distribution of the random variable :math:`Y = XA`, where :math:`X` is a matrix- or multi-variate normal random variable and :math:`A` a constant.
281+ Computes the distribution of the random variable :math:`Y = XA`, where :math:`X` is
282+ a matrix- or multi-variate normal random variable and :math:`A` a constant.
279283 """
280284 if norm_rv .ndim == 1 or (norm_rv .ndim == 2 and norm_rv .shape [0 ] == 1 ):
281285 if norm_rv .cov_cholesky_is_precomputed :
@@ -292,25 +296,25 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal
292296 cov = cov .reshape ((1 , 1 ))
293297
294298 return _Normal (mean = mean , cov = cov , cov_cholesky = cov_cholesky )
299+
300+ # This part does not do the Cholesky update,
301+ # because of performance configurations: currently, there is no way of switching
302+ # the Cholesky updates off, which might affect (large, potentially sparse)
303+ # covariance matrices of matrix-variate Normal RVs. See Issue #335.
304+ if constant_rv .support .ndim == 1 :
305+ constant_rv_support = constant_rv .support [:, None ]
295306 else :
296- # This part does not do the Cholesky update,
297- # because of performance configurations: currently, there is no way of switching
298- # the Cholesky updates off, which might affect (large, potentially sparse) covariance matrices
299- # of matrix-variate Normal RVs. See Issue #335.
300- if constant_rv .support .ndim == 1 :
301- constant_rv_support = constant_rv .support [:, None ]
302- else :
303- constant_rv_support = constant_rv .support
307+ constant_rv_support = constant_rv .support
304308
305- cov_update = _linear_operators .Kronecker (
306- _linear_operators .Identity (norm_rv .shape [0 ]), constant_rv_support .T
307- )
309+ cov_update = _linear_operators .Kronecker (
310+ _linear_operators .Identity (norm_rv .shape [0 ]), constant_rv_support .T
311+ )
308312
309- # Cov(rvec(XA)) = Cov((I (x) A.T)rvec(X)) = (I (x) A.T)Cov(rvec(X))(I (x) A.T).T
310- return _Normal (
311- mean = norm_rv .mean @ constant_rv .support ,
312- cov = cov_update @ (norm_rv .cov @ cov_update .T ),
313- )
313+ # Cov(rvec(XA)) = Cov((I (x) A.T)rvec(X)) = (I (x) A.T)Cov(rvec(X))(I (x) A.T).T
314+ return _Normal (
315+ mean = norm_rv .mean @ constant_rv .support ,
316+ cov = cov_update @ (norm_rv .cov @ cov_update .T ),
317+ )
314318
315319
316320_matmul_fns [(_Normal , _Constant )] = _matmul_normal_constant
@@ -319,7 +323,8 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal
319323def _matmul_constant_normal (constant_rv : _Constant , norm_rv : _Normal ) -> _Normal :
320324 """Matrix-multiplication with a normal random variable.
321325
322- Computes the distribution of the random variable :math:`Y = AX`, where :math:`X` is a matrix- or multi-variate normal random variable and :math:`A` a constant.
326+ Computes the distribution of the random variable :math:`Y = AX`, where :math:`X` is
327+ a matrix- or multi-variate normal random variable and :math:`A` a constant.
323328 """
324329 if norm_rv .ndim == 1 or (norm_rv .ndim == 2 and norm_rv .shape [1 ] == 1 ):
325330 if norm_rv .cov_cholesky_is_precomputed :
@@ -333,26 +338,26 @@ def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal
333338 cov = constant_rv .support @ (norm_rv .cov @ constant_rv .support .T ),
334339 cov_cholesky = cov_cholesky ,
335340 )
341+
342+ # This part does not do the Cholesky update,
343+ # because of performance configurations: currently, there is no way of switching
344+ # the Cholesky updates off, which might affect (large, potentially sparse)
345+ # covariance matrices of matrix-variate Normal RVs. See Issue #335.
346+ if constant_rv .support .ndim == 1 :
347+ constant_rv_support = constant_rv .support [None , :]
336348 else :
337- # This part does not do the Cholesky update,
338- # because of performance configurations: currently, there is no way of switching
339- # the Cholesky updates off, which might affect (large, potentially sparse) covariance matrices
340- # of matrix-variate Normal RVs. See Issue #335.
341- if constant_rv .support .ndim == 1 :
342- constant_rv_support = constant_rv .support [None , :]
343- else :
344- constant_rv_support = constant_rv .support
349+ constant_rv_support = constant_rv .support
345350
346- cov_update = _linear_operators .Kronecker (
347- constant_rv_support ,
348- _linear_operators .Identity (norm_rv .shape [1 ]),
349- )
351+ cov_update = _linear_operators .Kronecker (
352+ constant_rv_support ,
353+ _linear_operators .Identity (norm_rv .shape [1 ]),
354+ )
350355
351- # Cov(rvec(AX)) = Cov((A (x) I)rvec(X)) = (A (x) I)Cov(rvec(X))(A (x) I).T
352- return _Normal (
353- mean = constant_rv .support @ norm_rv .mean ,
354- cov = cov_update @ (norm_rv .cov @ cov_update .T ),
355- )
356+ # Cov(rvec(AX)) = Cov((A (x) I)rvec(X)) = (A (x) I)Cov(rvec(X))(A (x) I).T
357+ return _Normal (
358+ mean = constant_rv .support @ norm_rv .mean ,
359+ cov = cov_update @ (norm_rv .cov @ cov_update .T ),
360+ )
356361
357362
358363_matmul_fns [(_Constant , _Normal )] = _matmul_constant_normal
0 commit comments