Skip to content

Commit 67fe61d

Browse files
Fix most pylint messages in randvars
1 parent 0b9ad27 commit 67fe61d

File tree

4 files changed

+112
-85
lines changed

4 files changed

+112
-85
lines changed

src/probnum/randvars/_arithmetic.py

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -180,29 +180,32 @@ def _generic_rv_add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariab
180180
# Constant - Constant Arithmetic
181181
########################################################################################
182182

183-
_add_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.add)
184-
_sub_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.sub)
185-
_mul_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.mul)
186-
_matmul_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(
183+
_constant_constant_operator_factory = (
184+
_Constant._binary_operator_factory # pylint: disable=protected-access
185+
)
186+
187+
_add_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.add)
188+
_sub_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.sub)
189+
_mul_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.mul)
190+
_matmul_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(
187191
operator.matmul
188192
)
189-
_truediv_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(
193+
_truediv_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(
190194
operator.truediv
191195
)
192-
_floordiv_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(
196+
_floordiv_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(
193197
operator.floordiv
194198
)
195-
_mod_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.mod)
196-
_divmod_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(divmod)
197-
_pow_fns[(_Constant, _Constant)] = _Constant._binary_operator_factory(operator.pow)
199+
_mod_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.mod)
200+
_divmod_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(divmod)
201+
_pow_fns[(_Constant, _Constant)] = _constant_constant_operator_factory(operator.pow)
198202

199203
########################################################################################
200204
# Normal - Normal Arithmetic
201205
########################################################################################
202206

203-
_add_fns[(_Normal, _Normal)] = _Normal._add_normal
204-
_sub_fns[(_Normal, _Normal)] = _Normal._sub_normal
205-
207+
_add_fns[(_Normal, _Normal)] = _Normal._add_normal # pylint: disable=protected-access
208+
_sub_fns[(_Normal, _Normal)] = _Normal._sub_normal # pylint: disable=protected-access
206209

207210
########################################################################################
208211
# Normal - Constant Arithmetic
@@ -254,16 +257,16 @@ def _mul_normal_constant(
254257
return _Constant(
255258
support=backend.zeros_like(norm_rv.mean),
256259
)
260+
261+
if norm_rv.cov_cholesky_is_precomputed:
262+
cov_cholesky = constant_rv.support * norm_rv.cov_cholesky
257263
else:
258-
if norm_rv.cov_cholesky_is_precomputed:
259-
cov_cholesky = constant_rv.support * norm_rv.cov_cholesky
260-
else:
261-
cov_cholesky = None
262-
return _Normal(
263-
mean=constant_rv.support * norm_rv.mean,
264-
cov=(constant_rv.support ** 2) * norm_rv.cov,
265-
cov_cholesky=cov_cholesky,
266-
)
264+
cov_cholesky = None
265+
return _Normal(
266+
mean=constant_rv.support * norm_rv.mean,
267+
cov=(constant_rv.support ** 2) * norm_rv.cov,
268+
cov_cholesky=cov_cholesky,
269+
)
267270

268271
return NotImplemented
269272

@@ -275,7 +278,8 @@ def _mul_normal_constant(
275278
def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal:
276279
"""Normal random variable multiplied with a vector or matrix.
277280
278-
Computes the distribution of the random variable :math:`Y = XA`, where :math:`X` is a matrix- or multi-variate normal random variable and :math:`A` a constant.
281+
Computes the distribution of the random variable :math:`Y = XA`, where :math:`X` is
282+
a matrix- or multi-variate normal random variable and :math:`A` a constant.
279283
"""
280284
if norm_rv.ndim == 1 or (norm_rv.ndim == 2 and norm_rv.shape[0] == 1):
281285
if norm_rv.cov_cholesky_is_precomputed:
@@ -292,25 +296,25 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal
292296
cov = cov.reshape((1, 1))
293297

294298
return _Normal(mean=mean, cov=cov, cov_cholesky=cov_cholesky)
299+
300+
# This part does not do the Cholesky update,
301+
# because of performance configurations: currently, there is no way of switching
302+
# the Cholesky updates off, which might affect (large, potentially sparse)
303+
# covariance matrices of matrix-variate Normal RVs. See Issue #335.
304+
if constant_rv.support.ndim == 1:
305+
constant_rv_support = constant_rv.support[:, None]
295306
else:
296-
# This part does not do the Cholesky update,
297-
# because of performance configurations: currently, there is no way of switching
298-
# the Cholesky updates off, which might affect (large, potentially sparse) covariance matrices
299-
# of matrix-variate Normal RVs. See Issue #335.
300-
if constant_rv.support.ndim == 1:
301-
constant_rv_support = constant_rv.support[:, None]
302-
else:
303-
constant_rv_support = constant_rv.support
307+
constant_rv_support = constant_rv.support
304308

305-
cov_update = _linear_operators.Kronecker(
306-
_linear_operators.Identity(norm_rv.shape[0]), constant_rv_support.T
307-
)
309+
cov_update = _linear_operators.Kronecker(
310+
_linear_operators.Identity(norm_rv.shape[0]), constant_rv_support.T
311+
)
308312

309-
# Cov(rvec(XA)) = Cov((I (x) A.T)rvec(X)) = (I (x) A.T)Cov(rvec(X))(I (x) A.T).T
310-
return _Normal(
311-
mean=norm_rv.mean @ constant_rv.support,
312-
cov=cov_update @ (norm_rv.cov @ cov_update.T),
313-
)
313+
# Cov(rvec(XA)) = Cov((I (x) A.T)rvec(X)) = (I (x) A.T)Cov(rvec(X))(I (x) A.T).T
314+
return _Normal(
315+
mean=norm_rv.mean @ constant_rv.support,
316+
cov=cov_update @ (norm_rv.cov @ cov_update.T),
317+
)
314318

315319

316320
_matmul_fns[(_Normal, _Constant)] = _matmul_normal_constant
@@ -319,7 +323,8 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal
319323
def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal:
320324
"""Matrix-multiplication with a normal random variable.
321325
322-
Computes the distribution of the random variable :math:`Y = AX`, where :math:`X` is a matrix- or multi-variate normal random variable and :math:`A` a constant.
326+
Computes the distribution of the random variable :math:`Y = AX`, where :math:`X` is
327+
a matrix- or multi-variate normal random variable and :math:`A` a constant.
323328
"""
324329
if norm_rv.ndim == 1 or (norm_rv.ndim == 2 and norm_rv.shape[1] == 1):
325330
if norm_rv.cov_cholesky_is_precomputed:
@@ -333,26 +338,26 @@ def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal
333338
cov=constant_rv.support @ (norm_rv.cov @ constant_rv.support.T),
334339
cov_cholesky=cov_cholesky,
335340
)
341+
342+
# This part does not do the Cholesky update,
343+
# because of performance configurations: currently, there is no way of switching
344+
# the Cholesky updates off, which might affect (large, potentially sparse)
345+
# covariance matrices of matrix-variate Normal RVs. See Issue #335.
346+
if constant_rv.support.ndim == 1:
347+
constant_rv_support = constant_rv.support[None, :]
336348
else:
337-
# This part does not do the Cholesky update,
338-
# because of performance configurations: currently, there is no way of switching
339-
# the Cholesky updates off, which might affect (large, potentially sparse) covariance matrices
340-
# of matrix-variate Normal RVs. See Issue #335.
341-
if constant_rv.support.ndim == 1:
342-
constant_rv_support = constant_rv.support[None, :]
343-
else:
344-
constant_rv_support = constant_rv.support
349+
constant_rv_support = constant_rv.support
345350

346-
cov_update = _linear_operators.Kronecker(
347-
constant_rv_support,
348-
_linear_operators.Identity(norm_rv.shape[1]),
349-
)
351+
cov_update = _linear_operators.Kronecker(
352+
constant_rv_support,
353+
_linear_operators.Identity(norm_rv.shape[1]),
354+
)
350355

351-
# Cov(rvec(AX)) = Cov((A (x) I)rvec(X)) = (A (x) I)Cov(rvec(X))(A (x) I).T
352-
return _Normal(
353-
mean=constant_rv.support @ norm_rv.mean,
354-
cov=cov_update @ (norm_rv.cov @ cov_update.T),
355-
)
356+
# Cov(rvec(AX)) = Cov((A (x) I)rvec(X)) = (A (x) I)Cov(rvec(X))(A (x) I).T
357+
return _Normal(
358+
mean=constant_rv.support @ norm_rv.mean,
359+
cov=cov_update @ (norm_rv.cov @ cov_update.T),
360+
)
356361

357362

358363
_matmul_fns[(_Constant, _Normal)] = _matmul_constant_normal

src/probnum/randvars/_categorical.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import numpy as np
55

6+
from probnum import backend
7+
from probnum.typing import SeedType, ShapeType
8+
69
from ._random_variable import DiscreteRandomVariable
710

811

@@ -24,6 +27,12 @@ def __init__(
2427
probabilities: np.ndarray,
2528
support: Optional[np.ndarray] = None,
2629
):
30+
if backend.BACKEND != backend.Backend.NUMPY:
31+
raise NotImplementedError(
32+
"The `Categorical` random variable only supports the `numpy` backend "
33+
"at the moment."
34+
)
35+
2736
# The set of events is names "support" to be aligned with the method
2837
# DiscreteRandomVariable.in_support().
2938

@@ -39,7 +48,9 @@ def __init__(
3948
"num_categories": num_categories,
4049
}
4150

42-
def _sample_categorical(rng, size=()):
51+
def _sample_categorical(
52+
seed: np.random.SeedSequence, sample_shape: ShapeType = ()
53+
):
4354
"""Sample from a categorical distribution.
4455
4556
While on first sight, one might think that this
@@ -49,10 +60,12 @@ def _sample_categorical(rng, size=()):
4960
arrays with `ndim > 1`, but `self.support` can be just that.
5061
This detour via the `mask` avoids this problem.
5162
"""
52-
63+
rng = np.random.default_rng(seed)
5364
indices = rng.choice(
54-
np.arange(len(self.support)), size=size, p=self.probabilities
55-
).reshape(size)
65+
np.arange(len(self.support)),
66+
size=sample_shape,
67+
p=self.probabilities,
68+
).reshape(sample_shape)
5669
return self.support[indices]
5770

5871
def _pmf_categorical(x):
@@ -64,7 +77,8 @@ def _pmf_categorical(x):
6477
x = np.asarray(x)
6578
if x.dtype != self.dtype:
6679
raise ValueError(
67-
"The data type of x does not match with the data type of the support."
80+
"The data type of x does not match with the data type of the "
81+
"support."
6882
)
6983

7084
mask = (x == self.support).nonzero()[0]
@@ -93,7 +107,7 @@ def support(self) -> np.ndarray:
93107
"""Support of the categorical distribution."""
94108
return self._support
95109

96-
def resample(self, rng: np.random.Generator) -> "Categorical":
110+
def resample(self, seed: SeedType) -> "Categorical":
97111
"""Resample the support of the categorical random variable.
98112
99113
Return a new categorical random variable (RV), where the support
@@ -103,16 +117,17 @@ def resample(self, rng: np.random.Generator) -> "Categorical":
103117
104118
Parameters
105119
----------
106-
rng :
107-
Random number generator.
120+
seed
121+
Seed for random number generation
108122
109123
Returns
110124
-------
111125
Categorical
112-
Categorical random variable with resampled support (according to self.probabilities).
126+
Categorical random variable with resampled support (according to
127+
self.probabilities).
113128
"""
114129
num_events = len(self.support)
115-
new_support = self.sample(rng=rng, size=num_events)
130+
new_support = self.sample(seed, sample_shape=num_events)
116131
new_probabilities = np.ones(self.probabilities.shape) / num_events
117132
return Categorical(
118133
support=new_support,

src/probnum/randvars/_random_variable.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def in_support(self, x: ArrayType) -> ArrayType:
390390

391391
self._check_return_value(
392392
"in_support",
393-
input=x,
393+
input_value=x,
394394
return_value=in_support,
395395
expected_shape=x.shape[: -self.ndim],
396396
expected_dtype=backend.bool,
@@ -440,7 +440,7 @@ def cdf(self, x: ArrayType) -> ArrayType:
440440

441441
self._check_return_value(
442442
"cdf",
443-
input=x,
443+
input_value=x,
444444
return_value=cdf,
445445
expected_shape=x.shape[: -self.ndim],
446446
expected_dtype=backend.double,
@@ -471,7 +471,7 @@ def logcdf(self, x: ArrayType) -> ArrayType:
471471

472472
self._check_return_value(
473473
"logcdf",
474-
input=x,
474+
input_value=x,
475475
return_value=logcdf,
476476
expected_shape=x.shape[: -self.ndim],
477477
expected_dtype=backend.double,
@@ -505,7 +505,7 @@ def quantile(self, p: ArrayType) -> ArrayType:
505505

506506
self._check_return_value(
507507
"quantile",
508-
input=p,
508+
input_value=p,
509509
return_value=quantile,
510510
expected_shape=p.shape + self.shape,
511511
expected_dtype=self.dtype,
@@ -758,27 +758,30 @@ def _check_property_value(
758758
def _check_return_value(
759759
self,
760760
method_name: str,
761-
input: ArrayType,
761+
input_value: ArrayType,
762762
return_value: ArrayType,
763763
expected_shape: Optional[ShapeType] = None,
764764
expected_dtype: Optional[backend.dtype] = None,
765765
):
766+
# pylint: disable=too-many-arguments
767+
766768
if expected_shape is not None:
767769
if return_value.shape != expected_shape:
768770
raise ValueError(
769771
f"The return value of the function `{method_name}` does not have "
770-
f"the correct shape for an input with shape {input.shape} and a "
771-
f"random variable with shape {self.shape}. Expected "
772+
f"the correct shape for an input with shape {input_value.shape} "
773+
f"and a random variable with shape {self.shape}. Expected "
772774
f"{expected_shape} but got {return_value.shape}."
773775
)
774776

775777
if expected_dtype is not None:
776778
if return_value.dtype != expected_dtype:
777779
raise ValueError(
778780
f"The return value of the function `{method_name}` does not have "
779-
f"the correct dtype for an input with dtype {str(input.dtype)} and "
780-
f"a random variable with dtype {str(self.dtype)}. Expexted "
781-
f"{str(expected_dtype)} but got {str(return_value.dtype)}."
781+
f"the correct dtype for an input with dtype "
782+
f"{str(input_value.dtype)} and a random variable with dtype "
783+
f"{str(self.dtype)}. Expected {str(expected_dtype)} but got "
784+
f"{str(return_value.dtype)}."
782785
)
783786

784787

@@ -953,7 +956,7 @@ def pmf(self, x: ArrayType) -> ArrayType:
953956

954957
self._check_return_value(
955958
"pmf",
956-
input=x,
959+
input_value=x,
957960
return_value=pmf,
958961
expected_shape=x.shape[: -self.ndim],
959962
expected_dtype=backend.double,
@@ -984,7 +987,7 @@ def logpmf(self, x: ArrayType) -> ArrayType:
984987

985988
self._check_return_value(
986989
"logpmf",
987-
input=x,
990+
input_value=x,
988991
return_value=logpmf,
989992
expected_shape=x.shape[: -self.ndim],
990993
expected_dtype=backend.double,
@@ -1162,7 +1165,7 @@ def pdf(self, x: ArrayType) -> ArrayType:
11621165

11631166
self._check_return_value(
11641167
"pdf",
1165-
input=x,
1168+
input_value=x,
11661169
return_value=pdf,
11671170
expected_shape=x.shape[: -self.ndim],
11681171
expected_dtype=backend.double,
@@ -1193,7 +1196,7 @@ def logpdf(self, x: ArrayType) -> ArrayType:
11931196

11941197
self._check_return_value(
11951198
"logpdf",
1196-
input=x,
1199+
input_value=x,
11971200
return_value=logpdf,
11981201
expected_shape=x.shape[: -self.ndim],
11991202
expected_dtype=backend.double,

0 commit comments

Comments
 (0)