Skip to content

Commit e606d8e

Browse files
A few Pylint fixes
1 parent 0260f4f commit e606d8e

File tree

5 files changed

+41
-25
lines changed

5 files changed

+41
-25
lines changed

src/probnum/randvars/_constant.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44

55
import numpy as np
66

7-
from probnum import config, linops, utils as _utils
8-
from probnum.typing import ArrayLikeGetitemArgType, ArrayType, ShapeArgType, ShapeType
7+
from probnum import backend, config, linops, utils as _utils
8+
from probnum.typing import (
9+
ArrayLikeGetitemArgType,
10+
ArrayType,
11+
SeedType,
12+
ShapeArgType,
13+
ShapeType,
14+
)
915

1016
from . import _random_variable
1117

@@ -60,10 +66,7 @@ def __init__(
6066
self,
6167
support: ArrayType,
6268
):
63-
if np.isscalar(support):
64-
support = _utils.as_numpy_scalar(support)
65-
66-
self._support = support
69+
self._support = backend.asarray(support)
6770

6871
support_floating = self._support.astype(
6972
np.promote_types(self._support.dtype, np.float_)
@@ -142,13 +145,13 @@ def transpose(self, *axes: int) -> "Constant":
142145
support=self._support.transpose(*axes),
143146
)
144147

145-
def _sample(self, rng: np.random.Generator, size: ShapeArgType = ()) -> ArrayType:
146-
size = _utils.as_shape(size)
148+
def _sample(self, seed: SeedType, sample_shape: ShapeArgType = ()) -> ArrayType:
149+
# pylint: disable=unused-argument
147150

148-
if size == ():
151+
if sample_shape == ():
149152
return self._support.copy()
150-
else:
151-
return np.tile(self._support, reps=size + (1,) * self.ndim)
153+
154+
return np.tile(self._support, reps=sample_shape + (1,) * self.ndim)
152155

153156
# Unary arithmetic operations
154157

src/probnum/randvars/_normal.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,11 @@ def _sample(self, seed: SeedLike, sample_shape: ShapeType = ()) -> ArrayType:
460460
def _arg_todense(x: Union[ArrayType, linops.LinearOperator]) -> ArrayType:
461461
if isinstance(x, linops.LinearOperator):
462462
return x.todense()
463-
elif isinstance(x, backend.ndarray):
463+
464+
if isinstance(x, backend.ndarray):
464465
return x
465-
else:
466-
raise ValueError(f"Unsupported argument type {type(x)}")
466+
467+
raise ValueError(f"Unsupported argument type {type(x)}")
467468

468469
@backend.jit_method
469470
def _in_support(self, x: ArrayType) -> ArrayType:
@@ -504,7 +505,7 @@ def _cdf(self, x: ArrayType) -> ArrayType:
504505
if backend.BACKEND is not backend.Backend.NUMPY:
505506
raise NotImplementedError()
506507

507-
import scipy.stats
508+
import scipy.stats # pylint: disable=import-outside-toplevel
508509

509510
return scipy.stats.multivariate_normal.cdf(
510511
Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)),
@@ -516,7 +517,7 @@ def _logcdf(self, x: ArrayType) -> ArrayType:
516517
if backend.BACKEND is not backend.Backend.NUMPY:
517518
raise NotImplementedError()
518519

519-
import scipy.stats
520+
import scipy.stats # pylint: disable=import-outside-toplevel
520521

521522
return scipy.stats.multivariate_normal.logcdf(
522523
Normal._arg_todense(x).reshape(x.shape[: -self.ndim] + (-1,)),

src/probnum/randvars/_random_variable.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
ArrayLikeGetitemArgType,
1212
ArrayType,
1313
DTypeArgType,
14-
FloatArgType,
1514
ScalarType,
1615
SeedType,
1716
ShapeArgType,
@@ -337,7 +336,8 @@ def var(self) -> ArrayType:
337336
def std(self) -> ArrayType:
338337
"""Standard deviation of the random variable.
339338
340-
To learn about the dtype of the standard deviation, see :attr:`expectation_dtype`.
339+
To learn about the dtype of the standard deviation, see
340+
:attr:`expectation_dtype`.
341341
"""
342342
if self.__std is None:
343343
std = backend.sqrt(self.var)
@@ -578,7 +578,9 @@ def __neg__(self) -> "RandomVariable":
578578
return RandomVariable(
579579
shape=self.shape,
580580
dtype=self.dtype,
581-
sample=lambda rng, size: -self.sample(rng=rng, size=size),
581+
sample=lambda seed, sample_shape: -self.sample(
582+
seed=seed, sample_shape=sample_shape
583+
),
582584
in_support=lambda x: self.in_support(-x),
583585
mode=lambda: -self.mode,
584586
median=lambda: -self.median,
@@ -592,7 +594,9 @@ def __pos__(self) -> "RandomVariable":
592594
return RandomVariable(
593595
shape=self.shape,
594596
dtype=self.dtype,
595-
sample=lambda rng, size: +self.sample(rng=rng, size=size),
597+
sample=lambda seed, sample_shape: +self.sample(
598+
seed=seed, sample_shape=sample_shape
599+
),
596600
in_support=lambda x: self.in_support(+x),
597601
mode=lambda: +self.mode,
598602
median=lambda: +self.median,
@@ -606,7 +610,9 @@ def __abs__(self) -> "RandomVariable":
606610
return RandomVariable(
607611
shape=self.shape,
608612
dtype=self.dtype,
609-
sample=lambda rng, size: abs(self.sample(rng=rng, size=size)),
613+
sample=lambda seed, sample_shape: abs(
614+
self.sample(seed=seed, sample_shape=sample_shape)
615+
),
610616
)
611617

612618
# Binary arithmetic operations
@@ -891,6 +897,8 @@ def __init__(
891897
std: Optional[Callable[[], ArrayType]] = None,
892898
entropy: Optional[Callable[[], ScalarType]] = None,
893899
):
900+
# pylint: disable=too-many-arguments,too-many-locals
901+
894902
# Probability mass function
895903
self.__pmf = pmf
896904
self.__logpmf = logpmf
@@ -1098,6 +1106,8 @@ def __init__(
10981106
std: Optional[Callable[[], ArrayType]] = None,
10991107
entropy: Optional[Callable[[], ArrayType]] = None,
11001108
):
1109+
# pylint: disable=too-many-arguments,too-many-locals
1110+
11011111
# Probability density function
11021112
self.__pdf = pdf
11031113
self.__logpdf = logpdf

src/probnum/randvars/_randomvariablelist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from probnum import randvars
7+
from probnum.randvars import _random_variable
88

99
try:
1010
# functools.cached_property is only available in Python >=3.8
@@ -30,14 +30,16 @@ def __init__(self, rv_list: list):
3030
if len(rv_list) > 0:
3131

3232
# First element as a proxy for checking all elements
33-
if not isinstance(rv_list[0], randvars.RandomVariable):
33+
if not isinstance(rv_list[0], _random_variable.RandomVariable):
3434
raise TypeError(
3535
"RandomVariableList expects RandomVariable elements, but "
3636
+ f"first element has type {type(rv_list[0])}."
3737
)
3838
super().__init__(rv_list)
3939

40-
def __getitem__(self, idx) -> Union[randvars.RandomVariable, "_RandomVariableList"]:
40+
def __getitem__(
41+
self, idx
42+
) -> Union[_random_variable.RandomVariable, "_RandomVariableList"]:
4143

4244
result = super().__getitem__(idx)
4345
# Make sure to wrap the result into a _RandomVariableList if necessary

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ commands =
7979
pylint src/probnum/quad --disable="function-redefined,too-many-arguments,else-if-used,line-too-long,missing-module-docstring,missing-function-docstring,missing-raises-doc,missing-return-doc" --jobs=0
8080
pylint src/probnum/randprocs --disable="arguments-differ,arguments-renamed,too-many-instance-attributes,too-many-arguments,too-many-locals,protected-access,unused-argument,no-else-return,duplicate-code,line-too-long,missing-module-docstring,missing-function-docstring,missing-type-doc,missing-raises-doc,useless-param-doc,useless-type-doc,missing-return-doc,missing-return-type-doc" --jobs=0
8181
pylint src/probnum/randprocs/kernels --jobs=0
82-
pylint src/probnum/randvars --disable="too-many-arguments,too-many-locals,too-many-branches,too-few-public-methods,protected-access,unused-argument,no-else-return,duplicate-code,line-too-long,missing-function-docstring,missing-raises-doc,missing-return-doc" --jobs=0
82+
pylint src/probnum/randvars --disable="missing-function-docstring,missing-raises-doc,missing-return-doc" --jobs=0
8383
pylint src/probnum/utils --disable="no-else-return,else-if-used,line-too-long,missing-raises-doc,missing-return-doc,missing-return-type-doc" --jobs=0
8484
# Benchmark and Test Code Linting Pass
8585
# pylint benchmarks --disable="unused-argument,attribute-defined-outside-init,missing-function-docstring" --jobs=0 # not a work in progress, but final

0 commit comments

Comments
 (0)