Skip to content

Commit 38908ae

Browse files
Bugfix for gamma sampling
1 parent 88c1867 commit 38908ae

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/probnum/backend/random/_torch.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,25 @@ def standard_normal(seed: np.random.SeedSequence, shape=(), dtype=torch.double):
2727

2828
def gamma(
2929
seed: np.random.SeedSequence,
30-
a: torch.Tensor,
31-
scale=1.0,
30+
shape_param: torch.Tensor,
31+
scale_param=1.0,
3232
shape=(),
3333
dtype=torch.double,
3434
):
3535
rng = _make_rng(seed)
3636

37-
a = a.to(dtype)
38-
scale = scale.to(dtype)
37+
shape_param = torch.as_tensor(shape_param, dtype=dtype)
38+
scale_param = torch.as_tensor(scale_param, dtype=dtype)
3939

4040
# Adapted version of
4141
# https://github.com/pytorch/pytorch/blob/afff38182457f3500c265f232310438dded0e57d/torch/distributions/gamma.py#L59-L63
42-
a, scale = broadcast_all(a, scale)
42+
shape_param, scale_param = broadcast_all(shape_param, scale_param)
4343

44-
res_shape = shape + a.shape
44+
res_shape = shape + shape_param.shape
4545

46-
return torch._standard_gamma(a.expand(res_shape), rng) * scale.expand(res_shape)
46+
return torch._standard_gamma(
47+
shape_param.expand(res_shape), rng
48+
) * scale_param.expand(res_shape)
4749

4850

4951
def uniform_so_group(

0 commit comments

Comments
 (0)