Skip to content

Commit 63f0430

Browse files
MultiDistribution (#18)
Co-authored-by: Diggory Hardy <git@dhardy.name>
1 parent 26da522 commit 63f0430

File tree

5 files changed

+93
-27
lines changed

5 files changed

+93
-27
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [0.6.0] — Unreleased
88
- Bump to MSRV 1.85.0 and Edition 2024 in line with `rand` (#27)
99

10+
### Additions
11+
- `MultiDistribution` trait to sample more efficiently from multi-dimensional distributions (#18)
12+
13+
### Changes
14+
- Moved `Dirichlet` into the new `multi` module and implemented `MultiDistribution` for it (#18)
15+
1016
## [0.5.2]
1117

1218
### API Changes

src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
//! - [`Beta`] distribution
7373
//! - [`Triangular`] distribution
7474
//! - Multivariate probability distributions
75-
//! - [`Dirichlet`] distribution
75+
//! - [`multi::Dirichlet`] distribution
7676
//! - [`UnitSphere`] distribution
7777
//! - [`UnitBall`] distribution
7878
//! - [`UnitCircle`] distribution
@@ -100,8 +100,6 @@ pub use self::beta::{Beta, Error as BetaError};
100100
pub use self::binomial::{Binomial, Error as BinomialError};
101101
pub use self::cauchy::{Cauchy, Error as CauchyError};
102102
pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError};
103-
#[cfg(feature = "alloc")]
104-
pub use self::dirichlet::{Dirichlet, Error as DirichletError};
105103
pub use self::exponential::{Error as ExpError, Exp, Exp1};
106104
pub use self::fisher_f::{Error as FisherFError, FisherF};
107105
pub use self::frechet::{Error as FrechetError, Frechet};
@@ -130,6 +128,8 @@ pub use student_t::StudentT;
130128

131129
pub use num_traits;
132130

131+
#[cfg(feature = "alloc")]
132+
pub mod multi;
133133
#[cfg(feature = "alloc")]
134134
pub mod weighted;
135135

@@ -188,7 +188,6 @@ mod beta;
188188
mod binomial;
189189
mod cauchy;
190190
mod chi_squared;
191-
mod dirichlet;
192191
mod exponential;
193192
mod fisher_f;
194193
mod frechet;

src/dirichlet.rs renamed to src/multi/dirichlet.rs

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`.
1111
1212
#![cfg(feature = "alloc")]
13-
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
13+
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal, multi::MultiDistribution};
1414
use core::fmt;
1515
use num_traits::{Float, NumCast};
1616
use rand::Rng;
@@ -68,26 +68,29 @@ where
6868
}
6969
}
7070

71-
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N>
71+
impl<F, const N: usize> MultiDistribution<F> for DirichletFromGamma<F, N>
7272
where
7373
F: Float,
7474
StandardNormal: Distribution<F>,
7575
Exp1: Distribution<F>,
7676
Open01: Distribution<F>,
7777
{
78-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
79-
let mut samples = [F::zero(); N];
78+
fn sample_len(&self) -> usize {
79+
N
80+
}
81+
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
82+
assert_eq!(output.len(), N);
83+
8084
let mut sum = F::zero();
8185

82-
for (s, g) in samples.iter_mut().zip(self.samplers.iter()) {
86+
for (s, g) in output.iter_mut().zip(self.samplers.iter()) {
8387
*s = g.sample(rng);
8488
sum = sum + *s;
8589
}
8690
let invacc = F::one() / sum;
87-
for s in samples.iter_mut() {
91+
for s in output.iter_mut() {
8892
*s = *s * invacc;
8993
}
90-
samples
9194
}
9295
}
9396

@@ -149,24 +152,27 @@ where
149152
}
150153
}
151154

152-
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N>
155+
impl<F, const N: usize> MultiDistribution<F> for DirichletFromBeta<F, N>
153156
where
154157
F: Float,
155158
StandardNormal: Distribution<F>,
156159
Exp1: Distribution<F>,
157160
Open01: Distribution<F>,
158161
{
159-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
160-
let mut samples = [F::zero(); N];
162+
fn sample_len(&self) -> usize {
163+
N
164+
}
165+
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
166+
assert_eq!(output.len(), N);
167+
161168
let mut acc = F::one();
162169

163-
for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) {
170+
for (s, beta) in output.iter_mut().zip(self.samplers.iter()) {
164171
let beta_sample = beta.sample(rng);
165172
*s = acc * beta_sample;
166173
acc = acc * (F::one() - beta_sample);
167174
}
168-
samples[N - 1] = acc;
169-
samples
175+
output[N - 1] = acc;
170176
}
171177
}
172178

@@ -208,7 +214,8 @@ where
208214
///
209215
/// ```
210216
/// use rand::prelude::*;
211-
/// use rand_distr::Dirichlet;
217+
/// use rand_distr::multi::Dirichlet;
218+
/// use rand_distr::multi::MultiDistribution;
212219
///
213220
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
214221
/// let samples = dirichlet.sample(&mut rand::rng());
@@ -259,7 +266,7 @@ impl fmt::Display for Error {
259266
"failed to create required Gamma distribution for Dirichlet distribution"
260267
}
261268
Error::FailedToCreateBeta => {
262-
"failed to create required Beta distribition for Dirichlet distribution"
269+
"failed to create required Beta distribution for Dirichlet distribution"
263270
}
264271
})
265272
}
@@ -315,21 +322,34 @@ where
315322
}
316323
}
317324

318-
impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
325+
impl<F, const N: usize> MultiDistribution<F> for Dirichlet<F, N>
319326
where
320327
F: Float,
321328
StandardNormal: Distribution<F>,
322329
Exp1: Distribution<F>,
323330
Open01: Distribution<F>,
324331
{
325-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
332+
fn sample_len(&self) -> usize {
333+
N
334+
}
335+
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
326336
match &self.repr {
327-
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
328-
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
337+
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_slice(rng, output),
338+
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_slice(rng, output),
329339
}
330340
}
331341
}
332342

343+
impl<F, const N: usize> Distribution<Vec<F>> for Dirichlet<F, N>
344+
where
345+
F: Float + Default,
346+
StandardNormal: Distribution<F>,
347+
Exp1: Distribution<F>,
348+
Open01: Distribution<F>,
349+
{
350+
distribution_impl!(F);
351+
}
352+
333353
#[cfg(test)]
334354
mod test {
335355
use super::*;
@@ -403,7 +423,7 @@ mod test {
403423
let alpha_sum: f64 = alpha.iter().sum();
404424
let expected_mean = alpha.map(|x| x / alpha_sum);
405425
for i in 0..N {
406-
assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
426+
average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
407427
}
408428
}
409429

src/multi/mod.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright 2025 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
//! Contains multi-dimensional distributions.
10+
//!
11+
//! We provide a trait `MultiDistribution` which allows sampling from a multi-dimensional distribution into a caller-provided slice, without extra allocations.
12+
//! All multi-dimensional distributions implement `MultiDistribution`; they may additionally implement `Distribution<Vec<T>>` for convenience, which allocates a new `Vec` for each sample.
13+
14+
use rand::Rng;
15+
16+
/// A standard abstraction for distributions with multi-dimensional results
17+
pub trait MultiDistribution<T> {
18+
/// Returns the length of one sample (the dimension of the distribution).
19+
fn sample_len(&self) -> usize;
20+
/// Samples from the distribution and writes the result to `output`, which is expected to have length `sample_len()`.
21+
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [T]);
22+
}
23+
24+
macro_rules! distribution_impl {
25+
($scalar:ident) => {
26+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<$scalar> {
27+
use crate::multi::MultiDistribution;
28+
let mut buf = vec![Default::default(); self.sample_len()];
29+
self.sample_to_slice(rng, &mut buf);
30+
buf
31+
}
32+
};
33+
}
34+
35+
pub use dirichlet::Dirichlet;
36+
37+
mod dirichlet;

tests/value_stability.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,13 @@ fn weibull_stability() {
502502
fn dirichlet_stability() {
503503
let mut rng = get_rng(223);
504504
assert_eq!(
505-
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
505+
multi::Dirichlet::new([1.0, 2.0, 3.0])
506+
.unwrap()
507+
.sample(&mut rng),
506508
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
507509
);
508510
assert_eq!(
509-
rng.sample(Dirichlet::new([8.0; 5]).unwrap()),
511+
multi::Dirichlet::new([8.0; 5]).unwrap().sample(&mut rng),
510512
[
511513
0.17684200044809556,
512514
0.29915953935953055,
@@ -517,7 +519,9 @@ fn dirichlet_stability() {
517519
);
518520
// Test stability for the case where all alphas are less than 0.1.
519521
assert_eq!(
520-
rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()),
522+
multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05])
523+
.unwrap()
524+
.sample(&mut rng),
521525
[
522526
0.00027580456855692104,
523527
2.296135759821706e-20,

0 commit comments

Comments
 (0)