11"""Utility functions for random variables."""
22from typing import Any
33
4- import numpy as np
54import scipy .sparse
65
7- import probnum . linops
6+ from probnum import backend , linops
87
9- from . import _constant , _random_variable , _scipy_stats
8+ from . import _constant , _random_variable
109
1110
1211def asrandvar (obj : Any ) -> _random_variable .RandomVariable :
@@ -17,51 +16,40 @@ def asrandvar(obj: Any) -> _random_variable.RandomVariable:
1716
1817 Parameters
1918 ----------
20- obj :
19+ obj
2120 Object to be represented as a :class:`RandomVariable`.
2221
22+ Returns
23+ -------
24+ randvar
25+ Object as a :class:`RandomVariable`.
26+
27+ Raises
28+ ------
29+ ValueError
30+ If the object cannot be represented as a :class:`RandomVariable`.
31+
2332 See Also
2433 --------
2534 RandomVariable : Class representing random variables.
26-
27- Examples
28- --------
29- >>> from scipy.stats import bernoulli
30- >>> import probnum as pn
31- >>> import numpy as np
32- >>> bern = bernoulli(p=0.5)
33- >>> bern_pn = pn.asrandvar(bern)
34- >>> rng = np.random.default_rng(42)
35- >>> bern_pn.sample(rng=rng, size=5)
36- array([1, 0, 1, 1, 0])
3735 """
3836
39- # pylint: disable=protected-access
40-
4137 # RandomVariable
4238 if isinstance (obj , _random_variable .RandomVariable ):
4339 return obj
40+
4441 # Scalar
45- elif np . isscalar (obj ):
42+ if backend . ndim (obj ) == 0 :
4643 return _constant .Constant (support = obj )
47- # Numpy array or sparse matrix
48- elif isinstance (obj , (np .ndarray , scipy .sparse .spmatrix )):
44+
45+ # NumPy array or sparse matrix
46+ if isinstance (obj , (backend .ndarray , scipy .sparse .spmatrix )):
4947 return _constant .Constant (support = obj )
48+
5049 # Linear Operators
51- elif isinstance (
52- obj , (probnum .linops .LinearOperator , scipy .sparse .linalg .LinearOperator )
53- ):
54- return _constant .Constant (support = probnum .linops .aslinop (obj ))
55- # Scipy random variable
56- elif isinstance (
57- obj ,
58- (
59- scipy .stats ._distn_infrastructure .rv_frozen ,
60- scipy .stats ._multivariate .multi_rv_frozen ,
61- ),
62- ):
63- return _scipy_stats .wrap_scipy_rv (obj )
64- else :
65- raise ValueError (
66- f"Argument of type { type (obj )} cannot be converted to a random variable."
67- )
50+ if isinstance (obj , (linops .LinearOperator , scipy .sparse .linalg .LinearOperator )):
51+ return _constant .Constant (support = linops .aslinop (obj ))
52+
53+ raise ValueError (
54+ f"Argument of type { type (obj )} cannot be converted to a random variable."
55+ )
0 commit comments