Skip to content

Commit ebecd1f

Browse files
Further tests for SO(n) sampling
1 parent 477c825 commit ebecd1f

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/probnum/compat/_core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def to_numpy(*xs: Union[backend.ndarray, linops.LinearOperator]) -> Tuple[np.nda
2323

2424
res.append(x)
2525

26+
if len(xs) == 1:
27+
return res[0]
28+
2629
return tuple(res)
2730

2831

tests/test_backend/test_random/test_uniform_so_group.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import numpy as np
12
import pytest_cases
23

34
from probnum import backend, compat
45
from probnum.typing import SeedLike, ShapeType
56

67

7-
@pytest_cases.fixture
8+
@pytest_cases.fixture(scope="module")
89
@pytest_cases.parametrize("seed", (234789, 7890))
910
@pytest_cases.parametrize("n", (1, 2, 5, 9))
1011
@pytest_cases.parametrize("shape", ((), (1,), (2,), (3, 2)))
@@ -28,3 +29,11 @@ def test_orthogonal(so_group_sample: backend.ndarray):
2829
backend.broadcast_arrays(backend.eye(n), so_group_sample)[0],
2930
atol=1e-6 if so_group_sample.dtype == backend.single else 1e-12,
3031
)
32+
33+
34+
def test_determinant_1(so_group_sample: backend.ndarray):
35+
compat.testing.assert_allclose(
36+
np.linalg.det(compat.to_numpy(so_group_sample)),
37+
1.0,
38+
rtol=2e-6 if so_group_sample.dtype == backend.single else 1e-7,
39+
)

0 commit comments

Comments
 (0)