Skip to content

Commit 477c825

Browse files
uniform sampling on SO(n) in torch backend
1 parent abf0739 commit 477c825

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/probnum/backend/random/_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _uniform_so_group_pushforward_fn(omega: jnp.ndarray) -> jnp.ndarray:
8686

8787
D = jnp.append(
8888
D,
89-
(-1.0 if n % 2 == 0 else 1.0) * jnp.prod(D[:-1]),
89+
(-1.0 if n % 2 == 0 else 1.0) * jnp.prod(D),
9090
)
9191

9292
return D[:, None] * H

src/probnum/backend/random/_torch.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch
55
from torch.distributions.utils import broadcast_all
66

7+
from probnum.typing import DTypeArgType, ShapeArgType
8+
79
_RNG_STATE_SIZE = torch.Generator().get_state().shape[0]
810

911

@@ -44,6 +46,67 @@ def gamma(
4446
return torch._standard_gamma(a.expand(res_shape), rng) * scale.expand(res_shape)
4547

4648

49+
def uniform_so_group(
50+
seed: np.random.SeedSequence,
51+
n: int,
52+
shape: ShapeArgType = (),
53+
dtype: DTypeArgType = torch.double,
54+
) -> torch.Tensor:
55+
if n == 1:
56+
return torch.ones(shape + (1, 1), dtype=dtype)
57+
58+
omega = standard_normal(seed, shape=shape + (n - 1, n), dtype=dtype)
59+
60+
sample = _uniform_so_group_pushforward_fn(omega.reshape((-1, n - 1, n)))
61+
62+
return sample.reshape(shape + (n, n))
63+
64+
65+
@torch.jit.script
66+
def _uniform_so_group_pushforward_fn(omega: torch.Tensor) -> torch.Tensor:
67+
n = omega.shape[-1]
68+
69+
assert omega.ndim == 3 and omega.shape[-2] == n - 1
70+
71+
samples = []
72+
73+
for sample_idx in range(omega.shape[0]):
74+
X = torch.triu(omega[sample_idx, :, :])
75+
X_diag = torch.diag(X)
76+
77+
D = torch.where(
78+
X_diag != 0,
79+
torch.sign(X_diag),
80+
torch.ones((), dtype=omega.dtype),
81+
)
82+
83+
row_norms_sq = torch.sum(X ** 2, dim=1)
84+
85+
diag_indices = torch.arange(n - 1)
86+
X[diag_indices, diag_indices] = torch.sqrt(row_norms_sq) * D
87+
88+
X /= torch.sqrt((row_norms_sq - X_diag ** 2 + torch.diag(X) ** 2) / 2.0)[
89+
:, None
90+
]
91+
92+
H = torch.eye(n, dtype=omega.dtype)
93+
94+
for idx in range(n - 1):
95+
H -= torch.outer(H @ X[idx, :], X[idx, :])
96+
97+
D = torch.cat(
98+
(
99+
D,
100+
(-1.0 if n % 2 == 0 else 1.0) * torch.prod(D, dim=0, keepdim=True),
101+
),
102+
dim=0,
103+
)
104+
105+
samples.append(D[:, None] * H)
106+
107+
return torch.stack(samples, dim=0)
108+
109+
47110
def _make_rng(seed: np.random.SeedSequence) -> torch.Generator:
48111
rng = torch.Generator()
49112

0 commit comments

Comments
 (0)