Skip to content

Commit 099c576

Browse files
Fix normal arithmetic and cholesky updates
1 parent ef6de40 commit 099c576

File tree

7 files changed

+89
-41
lines changed

7 files changed

+89
-41
lines changed

src/probnum/randvars/_arithmetic.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _rv_binary_op(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable
124124

125125
def _make_rv_binary_op_result_shape_dtype_sample_fn(op_fn, rv1, rv2):
126126
def sample(seed, sample_shape):
127-
seed1, seed2, _ = backend.random.split(seed, 3)
127+
seed1, seed2 = backend.random.split(seed, 2)
128128

129129
return op_fn(
130130
rv1.sample(seed=seed1, sample_shape=sample_shape),
@@ -294,9 +294,17 @@ def _matmul_normal_constant(norm_rv: _Normal, constant_rv: _Constant) -> _Normal
294294
mean = norm_rv.mean @ constant_rv.support
295295
cov = constant_rv.support.T @ (norm_rv.cov @ constant_rv.support)
296296

297-
if cov.shape == () and mean.shape == (1,):
297+
if mean.shape == ():
298+
cov = cov.reshape(())
299+
300+
if cov_cholesky is not None:
301+
cov_cholesky = cov_cholesky.reshape(())
302+
elif mean.shape == (1,):
298303
cov = cov.reshape((1, 1))
299304

305+
if cov_cholesky is not None:
306+
cov_cholesky = cov_cholesky.reshape((1, 1))
307+
300308
return _Normal(mean=mean, cov=cov, cov_cholesky=cov_cholesky)
301309

302310
# This part does not do the Cholesky update,
@@ -335,11 +343,22 @@ def _matmul_constant_normal(constant_rv: _Constant, norm_rv: _Normal) -> _Normal
335343
)
336344
else:
337345
cov_cholesky = None
338-
return _Normal(
339-
mean=constant_rv.support @ norm_rv.mean,
340-
cov=constant_rv.support @ (norm_rv.cov @ constant_rv.support.T),
341-
cov_cholesky=cov_cholesky,
342-
)
346+
347+
mean = constant_rv.support @ norm_rv.mean
348+
cov = constant_rv.support @ (norm_rv.cov @ constant_rv.support.T)
349+
350+
if mean.shape == ():
351+
cov = cov.reshape(())
352+
353+
if cov_cholesky is not None:
354+
cov_cholesky = cov_cholesky.reshape(())
355+
elif mean.shape == (1,):
356+
cov = cov.reshape((1, 1))
357+
358+
if cov_cholesky is not None:
359+
cov_cholesky = cov_cholesky.reshape((1, 1))
360+
361+
return _Normal(mean=mean, cov=cov, cov_cholesky=cov_cholesky)
343362

344363
# This part does not do the Cholesky update,
345364
# because of performance configurations: currently, there is no way of switching

src/probnum/randvars/_normal.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,6 @@ def __init__(
120120

121121
if mean.ndim == 0:
122122
# Scalar Gaussian
123-
if self._cov_cholesky is None:
124-
self._cov_cholesky = backend.sqrt(cov)
125-
126123
self.__cov_op_cholesky = None
127124

128125
super().__init__(
@@ -255,6 +252,8 @@ def compute_cov_cholesky(
255252
lower=True,
256253
),
257254
)
255+
elif self.ndim == 0:
256+
self._cov_cholesky = backend.sqrt(self.cov)
258257
else:
259258
self.__cov_op_cholesky = linops.aslinop(
260259
backend.to_numpy(
@@ -275,10 +274,7 @@ def cov_cholesky_is_precomputed(self) -> bool:
275274
initialization or if (ii) the property `self.cov_cholesky` has
276275
been called before.
277276
"""
278-
if self.__cov_op_cholesky is None:
279-
return False
280-
281-
return True
277+
return self._cov_cholesky is not None or self.__cov_op_cholesky is not None
282278

283279
def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Normal":
284280
"""Marginalization in multi- and matrixvariate normal random variables,

tests/test_randvars/test_arithmetic/conftest.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,67 @@
22
import numpy as np
33
import pytest
44

5-
from probnum import linops, randvars
5+
from probnum import backend, linops, randvars
66
from probnum.problems.zoo.linalg import random_spd_matrix
77
from probnum.typing import ShapeArgType
8+
from tests.testing import seed_from_args
89

910

1011
@pytest.fixture
11-
def rng() -> np.random.Generator:
12-
return np.random.default_rng(42)
12+
def constant(shape_const: ShapeArgType) -> randvars.Constant:
13+
seed = seed_from_args(shape_const, 19836)
1314

14-
15-
@pytest.fixture
16-
def constant(shape_const: ShapeArgType, rng: np.random.Generator) -> randvars.Constant:
17-
return randvars.Constant(support=rng.normal(size=shape_const))
15+
return randvars.Constant(
16+
support=backend.random.standard_normal(seed, shape=shape_const)
17+
)
1818

1919

2020
@pytest.fixture
2121
def multivariate_normal(
22-
shape: ShapeArgType, precompute_cov_cholesky: bool, rng: np.random.Generator
22+
shape: ShapeArgType, precompute_cov_cholesky: bool
2323
) -> randvars.Normal:
24+
seed = seed_from_args(shape, precompute_cov_cholesky, 1908)
25+
seed_mean, seed_cov = backend.random.split(seed)
26+
2427
rv = randvars.Normal(
25-
mean=rng.normal(size=shape),
26-
cov=random_spd_matrix(rng=rng, dim=shape[0]),
28+
mean=backend.random.standard_normal(seed_mean, shape=shape),
29+
cov=random_spd_matrix(seed_cov, dim=shape[0]),
2730
)
2831
if precompute_cov_cholesky:
29-
rv.precompute_cov_cholesky()
32+
rv.compute_cov_cholesky()
3033
return rv
3134

3235

3336
@pytest.fixture
3437
def matrixvariate_normal(
35-
shape: ShapeArgType, precompute_cov_cholesky: bool, rng: np.random.Generator
38+
shape: ShapeArgType, precompute_cov_cholesky: bool
3639
) -> randvars.Normal:
40+
seed = seed_from_args(shape, precompute_cov_cholesky, 354)
41+
seed_mean, seed_cov_A, seed_cov_B = backend.random.split(seed, num=3)
42+
3743
rv = randvars.Normal(
38-
mean=rng.normal(size=shape),
44+
mean=backend.random.standard_normal(seed_mean, shape=shape),
3945
cov=linops.Kronecker(
40-
A=random_spd_matrix(dim=shape[0], rng=rng),
41-
B=random_spd_matrix(dim=shape[1], rng=rng),
46+
A=random_spd_matrix(seed_cov_A, dim=shape[0]),
47+
B=random_spd_matrix(seed_cov_B, dim=shape[1]),
4248
),
4349
)
4450
if precompute_cov_cholesky:
45-
rv.precompute_cov_cholesky()
51+
rv.compute_cov_cholesky()
4652
return rv
4753

4854

4955
@pytest.fixture
5056
def symmetric_matrixvariate_normal(
51-
shape: ShapeArgType, precompute_cov_cholesky: bool, rng: np.random.Generator
57+
shape: ShapeArgType, precompute_cov_cholesky: bool
5258
) -> randvars.Normal:
59+
seed = seed_from_args(shape, precompute_cov_cholesky, 246)
60+
seed_mean, seed_cov = backend.random.split(seed)
61+
5362
rv = randvars.Normal(
54-
mean=random_spd_matrix(dim=shape[0], rng=rng),
55-
cov=linops.SymmetricKronecker(A=random_spd_matrix(dim=shape[0], rng=rng)),
63+
mean=random_spd_matrix(seed_mean, dim=shape[0]),
64+
cov=linops.SymmetricKronecker(A=random_spd_matrix(seed_cov, dim=shape[0])),
5665
)
5766
if precompute_cov_cholesky:
58-
rv.precompute_cov_cholesky()
67+
rv.compute_cov_cholesky()
5968
return rv

tests/test_randvars/test_arithmetic/test_generic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
@pytest.mark.parametrize("shape,dtype", [((5,), np.single), ((2, 3), np.double)])
1212
def test_generic_randvar_dtype_shape_inference(shape: ShapeArgType, dtype: DTypeLike):
1313
x = randvars.RandomVariable(
14-
shape=shape, dtype=dtype, sample=lambda size, rng: np.zeros(size + shape)
14+
shape=shape,
15+
dtype=dtype,
16+
sample=lambda seed, sample_shape: np.zeros(sample_shape + shape),
1517
)
1618
y = np.array(5.0)
1719
z = x + y

tests/test_utils/test_linalg/test_cholesky_updates.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
import probnum.utils.linalg as utlin
5+
from probnum import backend
56
from probnum.problems.zoo.linalg import random_spd_matrix
67

78

@@ -13,20 +14,28 @@ def even_ndim():
1314

1415

1516
@pytest.fixture
16-
def rng():
17-
return np.random.default_rng(seed=123)
17+
def spdmats(even_ndim):
18+
seed = backend.random.seed(abs(hash(even_ndim)))
19+
seed1, seed2 = backend.random.split(seed, num=2)
20+
21+
spdmat1 = random_spd_matrix(seed1, dim=even_ndim)
22+
spdmat2 = random_spd_matrix(seed2, dim=even_ndim)
23+
24+
return spdmat1, spdmat2
1825

1926

2027
@pytest.fixture
21-
def spdmat1(even_ndim, rng):
22-
return random_spd_matrix(rng, dim=even_ndim)
28+
def spdmat1(spdmats):
29+
return spdmats[0]
2330

2431

2532
@pytest.fixture
26-
def spdmat2(even_ndim, rng):
27-
return random_spd_matrix(rng, dim=even_ndim)
33+
def spdmat2(spdmats):
34+
return spdmats[1]
2835

2936

37+
@pytest.mark.skipif_backend(backend.Backend.JAX)
38+
@pytest.mark.skipif_backend(backend.Backend.TORCH)
3039
def test_cholesky_update(spdmat1, spdmat2):
3140
expected = np.linalg.cholesky(spdmat1 + spdmat2)
3241

@@ -36,6 +45,8 @@ def test_cholesky_update(spdmat1, spdmat2):
3645
np.testing.assert_allclose(expected, received)
3746

3847

48+
@pytest.mark.skipif_backend(backend.Backend.JAX)
49+
@pytest.mark.skipif_backend(backend.Backend.TORCH)
3950
def test_cholesky_optional(spdmat1, even_ndim):
4051
"""Assert that cholesky_update() transforms a non-square matrix square-root into a
4152
correct Cholesky factor."""
@@ -46,6 +57,8 @@ def test_cholesky_optional(spdmat1, even_ndim):
4657
np.testing.assert_allclose(expected, received)
4758

4859

60+
@pytest.mark.skipif_backend(backend.Backend.JAX)
61+
@pytest.mark.skipif_backend(backend.Backend.TORCH)
4962
def test_tril_to_positive_tril():
5063

5164
# Make a random tril matrix

tests/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .assertions import *
2+
from .random import seed_from_args
23
from .statistics import *

tests/testing/random.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from collections.abc import Hashable
2+
3+
from probnum import backend
4+
from probnum.typing import SeedType
5+
6+
7+
def seed_from_args(*args: Hashable) -> SeedType:
8+
return backend.random.seed(abs(sum(map(hash, args))))

0 commit comments

Comments
 (0)