Skip to content

Commit 7110791

Browse files
Initialized tests for the new normal RV
1 parent 1ecbd87 commit 7110791

File tree

6 files changed

+44
-4
lines changed

6 files changed

+44
-4
lines changed

src/probnum/backend/_core/_torch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535

3636
def asdtype(x) -> torch.dtype:
37-
# Parse `x` with NumPy and convert `np.dtype`` into `torch.dtype`
37+
if isinstance(x, torch.dtype):
38+
return x
39+
3840
return torch.as_tensor(
3941
np.empty(
4042
(),

src/probnum/randprocs/_random_process.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
import numpy as np
77

8-
from probnum import randvars
9-
from probnum import utils as _utils
8+
from probnum import randvars, utils as _utils
109
from probnum.typing import DTypeArgType, IntArgType, ShapeArgType
1110

1211
_InputType = TypeVar("InputType")
@@ -68,7 +67,7 @@ def __repr__(self) -> str:
6867
)
6968

7069
@abc.abstractmethod
71-
def __call__(self, args: _InputType) -> randvars.RandomVariable[_OutputType]:
70+
def __call__(self, args: _InputType) -> randvars.RandomVariable:
7271
"""Evaluate the random process at a set of input arguments.
7372
7473
Parameters

tests/test_randvars/test_normal/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Test cases defining random variables with a normal distribution."""
2+
3+
from pytest_cases import case, parametrize
4+
5+
from probnum import backend, randvars
6+
from probnum.problems.zoo.linalg import random_spd_matrix
7+
from probnum.typing import ScalarLike
8+
9+
10+
@case(tags=["univariate"])
11+
@parametrize("mean", (-1.0, 1))
12+
@parametrize("var", (3.0, 2))
13+
def case_univariate(mean: ScalarLike, var: ScalarLike) -> randvars.Normal:
14+
return randvars.Normal(mean, var)
15+
16+
17+
@case(tags=["vectorvariate"])
18+
@parametrize("dim", [1, 2, 5, 10, 20])
19+
def case_vectorvariate(dim: int) -> randvars.Normal:
20+
mean = backend.random.standard_normal(backend.random.seed(654 + dim), shape=(dim,))
21+
cov = random_spd_matrix(backend.random.seed(846), dim)
22+
23+
return randvars.Normal(mean, cov)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Test properties of normal random variables."""
2+
import numpy as np
3+
import scipy.stats
4+
from pytest_cases import parametrize_with_cases
5+
6+
from probnum import backend
7+
8+
9+
@parametrize_with_cases("rv", cases=".cases", has_tag=["univariate"])
10+
def test_entropy(rv):
11+
scipy_entropy = scipy.stats.norm.entropy(
12+
loc=backend.to_numpy(rv.mean),
13+
scale=backend.to_numpy(rv.std),
14+
)
15+
16+
np.testing.assert_allclose(backend.to_numpy(rv.entropy), scipy_entropy)
File renamed without changes.

0 commit comments

Comments
 (0)