Skip to content

Commit d960cbb

Browse files
Remove scipy.stats compatibility
1 parent 08b036d commit d960cbb

File tree

3 files changed

+25
-309
lines changed

3 files changed

+25
-309
lines changed

src/probnum/randvars/__init__.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
RandomVariable,
1616
)
1717
from ._randomvariablelist import _RandomVariableList
18-
from ._scipy_stats import (
19-
WrappedSciPyContinuousRandomVariable,
20-
WrappedSciPyDiscreteRandomVariable,
21-
WrappedSciPyRandomVariable,
22-
)
2318
from ._sym_mat_normal import SymmetricMatrixNormal
2419
from ._utils import asrandvar
2520

@@ -33,9 +28,6 @@
3328
"Normal",
3429
"SymmetricMatrixNormal",
3530
"Categorical",
36-
"WrappedSciPyRandomVariable",
37-
"WrappedSciPyDiscreteRandomVariable",
38-
"WrappedSciPyContinuousRandomVariable",
3931
"_RandomVariableList",
4032
]
4133

@@ -44,10 +36,6 @@
4436
DiscreteRandomVariable.__module__ = "probnum.randvars"
4537
ContinuousRandomVariable.__module__ = "probnum.randvars"
4638

47-
WrappedSciPyRandomVariable.__module__ = "probnum.randvars"
48-
WrappedSciPyDiscreteRandomVariable.__module__ = "probnum.randvars"
49-
WrappedSciPyContinuousRandomVariable.__module__ = "probnum.randvars"
50-
5139
Constant.__module__ = "probnum.randvars"
5240
Normal.__module__ = "probnum.randvars"
5341
SymmetricMatrixNormal.__module__ = "probnum.randvars"

src/probnum/randvars/_scipy_stats.py

Lines changed: 0 additions & 260 deletions
This file was deleted.

src/probnum/randvars/_utils.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""Utility functions for random variables."""
22
from typing import Any
33

4-
import numpy as np
54
import scipy.sparse
65

7-
import probnum.linops
6+
from probnum import backend, linops
87

9-
from . import _constant, _random_variable, _scipy_stats
8+
from . import _constant, _random_variable
109

1110

1211
def asrandvar(obj: Any) -> _random_variable.RandomVariable:
@@ -17,51 +16,40 @@ def asrandvar(obj: Any) -> _random_variable.RandomVariable:
1716
1817
Parameters
1918
----------
20-
obj :
19+
obj
2120
Object to be represented as a :class:`RandomVariable`.
2221
22+
Returns
23+
-------
24+
randvar
25+
Object as a :class:`RandomVariable`.
26+
27+
Raises
28+
------
29+
ValueError
30+
If the object cannot be represented as a :class:`RandomVariable`.
31+
2332
See Also
2433
--------
2534
RandomVariable : Class representing random variables.
26-
27-
Examples
28-
--------
29-
>>> from scipy.stats import bernoulli
30-
>>> import probnum as pn
31-
>>> import numpy as np
32-
>>> bern = bernoulli(p=0.5)
33-
>>> bern_pn = pn.asrandvar(bern)
34-
>>> rng = np.random.default_rng(42)
35-
>>> bern_pn.sample(rng=rng, size=5)
36-
array([1, 0, 1, 1, 0])
3735
"""
3836

39-
# pylint: disable=protected-access
40-
4137
# RandomVariable
4238
if isinstance(obj, _random_variable.RandomVariable):
4339
return obj
40+
4441
# Scalar
45-
elif np.isscalar(obj):
42+
if backend.ndim(obj) == 0:
4643
return _constant.Constant(support=obj)
47-
# Numpy array or sparse matrix
48-
elif isinstance(obj, (np.ndarray, scipy.sparse.spmatrix)):
44+
45+
# NumPy array or sparse matrix
46+
if isinstance(obj, (backend.ndarray, scipy.sparse.spmatrix)):
4947
return _constant.Constant(support=obj)
48+
5049
# Linear Operators
51-
elif isinstance(
52-
obj, (probnum.linops.LinearOperator, scipy.sparse.linalg.LinearOperator)
53-
):
54-
return _constant.Constant(support=probnum.linops.aslinop(obj))
55-
# Scipy random variable
56-
elif isinstance(
57-
obj,
58-
(
59-
scipy.stats._distn_infrastructure.rv_frozen,
60-
scipy.stats._multivariate.multi_rv_frozen,
61-
),
62-
):
63-
return _scipy_stats.wrap_scipy_rv(obj)
64-
else:
65-
raise ValueError(
66-
f"Argument of type {type(obj)} cannot be converted to a random variable."
67-
)
50+
if isinstance(obj, (linops.LinearOperator, scipy.sparse.linalg.LinearOperator)):
51+
return _constant.Constant(support=linops.aslinop(obj))
52+
53+
raise ValueError(
54+
f"Argument of type {type(obj)} cannot be converted to a random variable."
55+
)

0 commit comments

Comments
 (0)