|
4 | 4 | import torch |
5 | 5 | from torch.distributions.utils import broadcast_all |
6 | 6 |
|
| 7 | +from probnum.typing import DTypeArgType, ShapeArgType |
| 8 | + |
7 | 9 | _RNG_STATE_SIZE = torch.Generator().get_state().shape[0] |
8 | 10 |
|
9 | 11 |
|
@@ -44,6 +46,67 @@ def gamma( |
44 | 46 | return torch._standard_gamma(a.expand(res_shape), rng) * scale.expand(res_shape) |
45 | 47 |
|
46 | 48 |
|
| 49 | +def uniform_so_group( |
| 50 | + seed: np.random.SeedSequence, |
| 51 | + n: int, |
| 52 | + shape: ShapeArgType = (), |
| 53 | + dtype: DTypeArgType = torch.double, |
| 54 | +) -> torch.Tensor: |
| 55 | + if n == 1: |
| 56 | + return torch.ones(shape + (1, 1), dtype=dtype) |
| 57 | + |
| 58 | + omega = standard_normal(seed, shape=shape + (n - 1, n), dtype=dtype) |
| 59 | + |
| 60 | + sample = _uniform_so_group_pushforward_fn(omega.reshape((-1, n - 1, n))) |
| 61 | + |
| 62 | + return sample.reshape(shape + (n, n)) |
| 63 | + |
| 64 | + |
| 65 | +@torch.jit.script |
| 66 | +def _uniform_so_group_pushforward_fn(omega: torch.Tensor) -> torch.Tensor: |
| 67 | + n = omega.shape[-1] |
| 68 | + |
| 69 | + assert omega.ndim == 3 and omega.shape[-2] == n - 1 |
| 70 | + |
| 71 | + samples = [] |
| 72 | + |
| 73 | + for sample_idx in range(omega.shape[0]): |
| 74 | + X = torch.triu(omega[sample_idx, :, :]) |
| 75 | + X_diag = torch.diag(X) |
| 76 | + |
| 77 | + D = torch.where( |
| 78 | + X_diag != 0, |
| 79 | + torch.sign(X_diag), |
| 80 | + torch.ones((), dtype=omega.dtype), |
| 81 | + ) |
| 82 | + |
| 83 | + row_norms_sq = torch.sum(X ** 2, dim=1) |
| 84 | + |
| 85 | + diag_indices = torch.arange(n - 1) |
| 86 | + X[diag_indices, diag_indices] = torch.sqrt(row_norms_sq) * D |
| 87 | + |
| 88 | + X /= torch.sqrt((row_norms_sq - X_diag ** 2 + torch.diag(X) ** 2) / 2.0)[ |
| 89 | + :, None |
| 90 | + ] |
| 91 | + |
| 92 | + H = torch.eye(n, dtype=omega.dtype) |
| 93 | + |
| 94 | + for idx in range(n - 1): |
| 95 | + H -= torch.outer(H @ X[idx, :], X[idx, :]) |
| 96 | + |
| 97 | + D = torch.cat( |
| 98 | + ( |
| 99 | + D, |
| 100 | + (-1.0 if n % 2 == 0 else 1.0) * torch.prod(D, dim=0, keepdim=True), |
| 101 | + ), |
| 102 | + dim=0, |
| 103 | + ) |
| 104 | + |
| 105 | + samples.append(D[:, None] * H) |
| 106 | + |
| 107 | + return torch.stack(samples, dim=0) |
| 108 | + |
| 109 | + |
47 | 110 | def _make_rng(seed: np.random.SeedSequence) -> torch.Generator: |
48 | 111 | rng = torch.Generator() |
49 | 112 |
|
|
0 commit comments