Skip to content

Commit 10789bf

Browse files
Remove as_value_type and exchange with asarray
1 parent 6e8d053 commit 10789bf

File tree

1 file changed

+13
-81
lines changed

1 file changed

+13
-81
lines changed

src/probnum/randvars/_random_variable.py

Lines changed: 13 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,6 @@ class RandomVariable:
7373
(Element-wise) standard deviation of the random variable.
7474
entropy :
7575
Information-theoretic entropy :math:`H(X)` of the random variable.
76-
as_value_type :
77-
Function which can be used to transform user-supplied arguments, interpreted as
78-
realizations of this random variable, to an easy-to-process, normalized format.
79-
Will be called internally to transform the argument of functions like
80-
:meth:`~RandomVariable.in_support`, :meth:`~RandomVariable.cdf`
81-
and :meth:`~RandomVariable.logcdf`, :meth:`~DiscreteRandomVariable.pmf`
82-
and :meth:`~DiscreteRandomVariable.logpmf` (in :class:`DiscreteRandomVariable`),
83-
:meth:`~ContinuousRandomVariable.pdf` and
84-
:meth:`~ContinuousRandomVariable.logpdf` (in :class:`ContinuousRandomVariable`),
85-
and potentially by similar functions in subclasses.
86-
87-
For instance, this method is useful if (``log``)
88-
:meth:`~ContinousRandomVariable.cdf` and (``log``)
89-
:meth:`~ContinuousRandomVariable.pdf` both only work on :class:`numpy.float_`
90-
arguments, but we still want the user to be able to pass Python
91-
:class:`float`. Then :meth:`~RandomVariable.as_value_type`
92-
should be set to something like ``lambda x: np.float64(x)``.
9376
9477
See Also
9578
--------
@@ -133,7 +116,6 @@ def __init__(
133116
var: Optional[Callable[[], ArrayType]] = None,
134117
std: Optional[Callable[[], ArrayType]] = None,
135118
entropy: Optional[Callable[[], ScalarType]] = None,
136-
as_value_type: Optional[Callable[[Any], ArrayType]] = None,
137119
):
138120
# pylint: disable=too-many-arguments,too-many-locals
139121
"""Create a new random variable."""
@@ -161,9 +143,6 @@ def __init__(
161143
self.__std = std
162144
self.__entropy = entropy
163145

164-
# Utilities
165-
self.__as_value_type = as_value_type
166-
167146
def __repr__(self) -> str:
168147
return (
169148
f"<{self.__class__.__name__} with shape={self.shape}, dtype"
@@ -407,7 +386,7 @@ def in_support(self, x: ArrayType) -> ArrayType:
407386
if self.__in_support is None:
408387
raise NotImplementedError
409388

410-
in_support = self.__in_support(self._as_value_type(x))
389+
in_support = self.__in_support(backend.asarray(x))
411390

412391
self._check_return_value(
413392
"in_support",
@@ -450,9 +429,9 @@ def cdf(self, x: ArrayType) -> ArrayType:
450429
The cdf evaluation will be broadcast over all additional dimensions.
451430
"""
452431
if self.__cdf is not None:
453-
cdf = self.__cdf(self._as_value_type(x))
432+
cdf = self.__cdf(backend.asarray(x))
454433
elif self.__logcdf is not None:
455-
cdf = backend.exp(self.logcdf(self._as_value_type(x)))
434+
cdf = backend.exp(self.logcdf(x))
456435
else:
457436
raise NotImplementedError(
458437
f"Neither the `cdf` nor the `logcdf` of the random variable object "
@@ -481,9 +460,9 @@ def logcdf(self, x: ArrayType) -> ArrayType:
481460
The logcdf evaluation will be broadcast over all additional dimensions.
482461
"""
483462
if self.__logcdf is not None:
484-
logcdf = self.__logcdf(self._as_value_type(x))
463+
logcdf = self.__logcdf(backend.asarray(x))
485464
elif self.__cdf is not None:
486-
logcdf = backend.log(self.__cdf(x))
465+
logcdf = backend.log(self.cdf(x))
487466
else:
488467
raise NotImplementedError(
489468
f"Neither the `logcdf` nor the `cdf` of the random variable object "
@@ -544,7 +523,6 @@ def __getitem__(self, key: ArrayLikeGetitemArgType) -> "RandomVariable":
544523
var=lambda: self.var[key],
545524
std=lambda: self.std[key],
546525
entropy=lambda: self.entropy,
547-
as_value_type=self.__as_value_type,
548526
)
549527

550528
def reshape(self, newshape: ShapeArgType) -> "RandomVariable":
@@ -569,7 +547,6 @@ def reshape(self, newshape: ShapeArgType) -> "RandomVariable":
569547
var=lambda: self.var.reshape(newshape),
570548
std=lambda: self.std.reshape(newshape),
571549
entropy=lambda: self.entropy,
572-
as_value_type=self.__as_value_type,
573550
)
574551

575552
def transpose(self, *axes: int) -> "RandomVariable":
@@ -591,7 +568,6 @@ def transpose(self, *axes: int) -> "RandomVariable":
591568
var=lambda: self.var.transpose(*axes),
592569
std=lambda: self.std.transpose(*axes),
593570
entropy=lambda: self.entropy,
594-
as_value_type=self.__as_value_type,
595571
)
596572

597573
T = property(transpose)
@@ -610,7 +586,6 @@ def __neg__(self) -> "RandomVariable":
610586
cov=lambda: self.cov,
611587
var=lambda: self.var,
612588
std=lambda: self.std,
613-
as_value_type=self.__as_value_type,
614589
)
615590

616591
def __pos__(self) -> "RandomVariable":
@@ -625,7 +600,6 @@ def __pos__(self) -> "RandomVariable":
625600
cov=lambda: self.cov,
626601
var=lambda: self.var,
627602
std=lambda: self.std,
628-
as_value_type=self.__as_value_type,
629603
)
630604

631605
def __abs__(self) -> "RandomVariable":
@@ -754,12 +728,6 @@ def __rpow__(self, other: Any) -> "RandomVariable":
754728

755729
return pow_(other, self)
756730

757-
def _as_value_type(self, x: Any) -> ArrayType:
758-
if self.__as_value_type is not None:
759-
return self.__as_value_type(x)
760-
761-
return x
762-
763731
@staticmethod
764732
def _check_property_value(
765733
name: str,
@@ -851,21 +819,6 @@ class DiscreteRandomVariable(RandomVariable):
851819
(Element-wise) standard deviation of the random variable.
852820
entropy :
853821
Shannon entropy :math:`H(X)` of the random variable.
854-
as_value_type :
855-
Function which can be used to transform user-supplied arguments, interpreted as
856-
realizations of this random variable, to an easy-to-process, normalized format.
857-
Will be called internally to transform the argument of functions like
858-
:meth:`~DiscreteRandomVariable.in_support`, :meth:`~DiscreteRandomVariable.cdf`
859-
and :meth:`~DiscreteRandomVariable.logcdf`, :meth:`~DiscreteRandomVariable.pmf`
860-
and :meth:`~DiscreteRandomVariable.logpmf`, and potentially by similar
861-
functions in subclasses.
862-
863-
For instance, this method is useful if (``log``)
864-
:meth:`~DiscreteRandomVariable.cdf` and (``log``)
865-
:meth:`~DiscreteRandomVariable.pmf` both only work on :class:`numpy.float_`
866-
arguments, but we still want the user to be able to pass Python
867-
:class:`float`. Then :meth:`~DiscreteRandomVariable.as_value_type`
868-
should be set to something like ``lambda x: np.float64(x)``.
869822
870823
See Also
871824
--------
@@ -937,7 +890,6 @@ def __init__(
937890
var: Optional[Callable[[], ArrayType]] = None,
938891
std: Optional[Callable[[], ArrayType]] = None,
939892
entropy: Optional[Callable[[], ScalarType]] = None,
940-
as_value_type: Optional[Callable[[Any], ArrayType]] = None,
941893
):
942894
# Probability mass function
943895
self.__pmf = pmf
@@ -959,7 +911,6 @@ def __init__(
959911
var=var,
960912
std=std,
961913
entropy=entropy,
962-
as_value_type=as_value_type,
963914
)
964915

965916
def pmf(self, x: ArrayType) -> ArrayType:
@@ -983,9 +934,9 @@ def pmf(self, x: ArrayType) -> ArrayType:
983934
The pmf evaluation will be broadcast over all additional dimensions.
984935
"""
985936
if self.__pmf is not None:
986-
pmf = self.__pmf(x)
937+
pmf = self.__pmf(backend.asarray(x))
987938
elif self.__logpmf is not None:
988-
pmf = backend.exp(self.__logpmf(x))
939+
pmf = backend.exp(self.logpmf(x))
989940
else:
990941
raise NotImplementedError(
991942
f"Neither the `pmf` nor the `logpmf` of the discrete random variable "
@@ -1014,9 +965,9 @@ def logpmf(self, x: ArrayType) -> ArrayType:
1014965
The logpmf evaluation will be broadcast over all additional dimensions.
1015966
"""
1016967
if self.__logpmf is not None:
1017-
logpmf = self.__logpmf(self._as_value_type(x))
968+
logpmf = self.__logpmf(backend.asarray(x))
1018969
elif self.__pmf is not None:
1019-
logpmf = backend.log(self.__pmf(self._as_value_type(x)))
970+
logpmf = backend.log(self.pmf(x))
1020971
else:
1021972
raise NotImplementedError(
1022973
f"Neither the `logpmf` nor the `pmf` of the discrete random variable "
@@ -1077,23 +1028,6 @@ class ContinuousRandomVariable(RandomVariable):
10771028
(Element-wise) standard deviation of the random variable.
10781029
entropy :
10791030
Differential entropy :math:`H(X)` of the random variable.
1080-
as_value_type :
1081-
Function which can be used to transform user-supplied arguments, interpreted as
1082-
realizations of this random variable, to an easy-to-process, normalized format.
1083-
Will be called internally to transform the argument of functions like
1084-
:meth:`~ContinuousRandomVariable.in_support`,
1085-
:meth:`~ContinuousRandomVariable.cdf`
1086-
and :meth:`~ContinuousRandomVariable.logcdf`,
1087-
:meth:`~ContinuousRandomVariable.pdf` and
1088-
:meth:`~ContinuousRandomVariable.logpdf`, and potentially by similar
1089-
functions in subclasses.
1090-
1091-
For instance, this method is useful if (``log``)
1092-
:meth:`~ContinuousRandomVariable.cdf` and (``log``)
1093-
:meth:`~ContinuousRandomVariable.pdf` both only work on :class:`numpy.float_`
1094-
arguments, but we still want the user to be able to pass Python
1095-
:class:`float`. Then :meth:`~ContinuousRandomVariable.as_value_type`
1096-
should be set to something like ``lambda x: np.float64(x)``.
10971031
10981032
See Also
10991033
--------
@@ -1163,7 +1097,6 @@ def __init__(
11631097
var: Optional[Callable[[], ArrayType]] = None,
11641098
std: Optional[Callable[[], ArrayType]] = None,
11651099
entropy: Optional[Callable[[], ArrayType]] = None,
1166-
as_value_type: Optional[Callable[[Any], ArrayType]] = None,
11671100
):
11681101
# Probability density function
11691102
self.__pdf = pdf
@@ -1185,7 +1118,6 @@ def __init__(
11851118
var=var,
11861119
std=std,
11871120
entropy=entropy,
1188-
as_value_type=as_value_type,
11891121
)
11901122

11911123
def pdf(self, x: ArrayType) -> ArrayType:
@@ -1209,9 +1141,9 @@ def pdf(self, x: ArrayType) -> ArrayType:
12091141
The pdf evaluation will be broadcast over all additional dimensions.
12101142
"""
12111143
if self.__pdf is not None:
1212-
pdf = self.__pdf(self._as_value_type(x))
1144+
pdf = self.__pdf(backend.asarray(x))
12131145
elif self.__logpdf is not None:
1214-
pdf = backend.exp(self.__logpdf(self._as_value_type(x)))
1146+
pdf = backend.exp(self.logpdf(x))
12151147
else:
12161148
raise NotImplementedError(
12171149
f"Neither the `pdf` nor the `logpdf` of the continuous random variable "
@@ -1240,9 +1172,9 @@ def logpdf(self, x: ArrayType) -> ArrayType:
12401172
The logpdf evaluation will be broadcast over all additional dimensions.
12411173
"""
12421174
if self.__logpdf is not None:
1243-
logpdf = self.__logpdf(self._as_value_type(x))
1175+
logpdf = self.__logpdf(backend.asarray(x))
12441176
elif self.__pdf is not None:
1245-
logpdf = backend.log(self.__pdf(self._as_value_type(x)))
1177+
logpdf = backend.log(self.pdf(x))
12461178
else:
12471179
raise NotImplementedError(
12481180
f"Neither the `logpdf` nor the `pdf` of the continuous random variable "

0 commit comments

Comments
 (0)