Skip to content

Commit d2146ee

Browse files
authored
Merge pull request #30 from rust-random/push-ywmuottvwtvt
Remove const-generic size parameter from Dirichlet distribution
2 parents 85a6ee0 + bab5ad3 commit d2146ee

File tree

3 files changed

+54
-50
lines changed

3 files changed

+54
-50
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Changes
1515
- Moved `Dirichlet` into the new `multi` module and implement `MultiDistribution` for it (#18)
16+
- `Dirichlet` no longer uses `const` generics, which means that its size is not required at compile time. Essentially a revert of rand#1292. (#15)
17+
- Add `Dirichlet::new_with_size` constructor (#15)
1618

1719
## [0.5.2]
1820

src/multi/dirichlet.rs

Lines changed: 49 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,24 @@ use alloc::{boxed::Box, vec, vec::Vec};
2121

2222
#[derive(Clone, Debug, PartialEq)]
2323
#[cfg_attr(feature = "serde", serde_as)]
24-
struct DirichletFromGamma<F, const N: usize>
24+
struct DirichletFromGamma<F>
2525
where
2626
F: Float,
2727
StandardNormal: Distribution<F>,
2828
Exp1: Distribution<F>,
2929
Open01: Distribution<F>,
3030
{
31-
samplers: [Gamma<F>; N],
31+
samplers: Vec<Gamma<F>>,
3232
}
3333

3434
/// Error type returned from [`DirchletFromGamma::new`].
3535
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
3636
enum DirichletFromGammaError {
3737
/// Gamma::new(a, 1) failed.
3838
GammmaNewFailed,
39-
40-
/// gamma_dists.try_into() failed (in theory, this should not happen).
41-
GammaArrayCreationFailed,
4239
}
4340

44-
impl<F, const N: usize> DirichletFromGamma<F, N>
41+
impl<F> DirichletFromGamma<F>
4542
where
4643
F: Float,
4744
StandardNormal: Distribution<F>,
@@ -53,33 +50,32 @@ where
5350
/// This function is part of a private implementation detail.
5451
/// It assumes that the input is correct, so no validation of alpha is done.
5552
#[inline]
56-
fn new(alpha: [F; N]) -> Result<DirichletFromGamma<F, N>, DirichletFromGammaError> {
53+
fn new(alpha: &[F]) -> Result<DirichletFromGamma<F>, DirichletFromGammaError> {
5754
let mut gamma_dists = Vec::new();
5855
for a in alpha {
5956
let dist =
60-
Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
57+
Gamma::new(*a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
6158
gamma_dists.push(dist);
6259
}
6360
Ok(DirichletFromGamma {
64-
samplers: gamma_dists
65-
.try_into()
66-
.map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?,
61+
samplers: gamma_dists,
6762
})
6863
}
6964
}
7065

71-
impl<F, const N: usize> MultiDistribution<F> for DirichletFromGamma<F, N>
66+
impl<F> MultiDistribution<F> for DirichletFromGamma<F>
7267
where
7368
F: Float,
7469
StandardNormal: Distribution<F>,
7570
Exp1: Distribution<F>,
7671
Open01: Distribution<F>,
7772
{
73+
#[inline]
7874
fn sample_len(&self) -> usize {
79-
N
75+
self.samplers.len()
8076
}
8177
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
82-
assert_eq!(output.len(), N);
78+
assert_eq!(output.len(), self.sample_len());
8379

8480
let mut sum = F::zero();
8581

@@ -96,7 +92,7 @@ where
9692

9793
#[derive(Clone, Debug, PartialEq)]
9894
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
99-
struct DirichletFromBeta<F, const N: usize>
95+
struct DirichletFromBeta<F>
10096
where
10197
F: Float,
10298
StandardNormal: Distribution<F>,
@@ -113,7 +109,7 @@ enum DirichletFromBetaError {
113109
BetaNewFailed,
114110
}
115111

116-
impl<F, const N: usize> DirichletFromBeta<F, N>
112+
impl<F> DirichletFromBeta<F>
117113
where
118114
F: Float,
119115
StandardNormal: Distribution<F>,
@@ -125,15 +121,16 @@ where
125121
/// This function is part of a private implementation detail.
126122
/// It assumes that the input is correct, so no validation of alpha is done.
127123
#[inline]
128-
fn new(alpha: [F; N]) -> Result<DirichletFromBeta<F, N>, DirichletFromBetaError> {
124+
fn new(alpha: &[F]) -> Result<DirichletFromBeta<F>, DirichletFromBetaError> {
129125
// `alpha_rev_csum` is the reverse of the cumulative sum of the
130126
// reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then
131127
// `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`.
132128
// Note that instances of DirichletFromBeta will always have N >= 2,
133129
// so the subtractions of 1, 2 and 3 from N in the following are safe.
134-
let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1];
135-
for k in 0..(N - 2) {
136-
alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k];
130+
let n = alpha.len();
131+
let mut alpha_rev_csum = vec![alpha[n - 1]; n - 1];
132+
for k in 0..(n - 2) {
133+
alpha_rev_csum[n - 3 - k] = alpha_rev_csum[n - 2 - k] + alpha[n - 2 - k];
137134
}
138135

139136
// Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example
@@ -142,7 +139,7 @@ where
142139
// Then pass each tuple to `Beta::new()` to create the `Beta`
143140
// instances.
144141
let mut beta_dists = Vec::new();
145-
for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) {
142+
for (&a, &b) in alpha[..(n - 1)].iter().zip(alpha_rev_csum.iter()) {
146143
let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?;
147144
beta_dists.push(dist);
148145
}
@@ -152,18 +149,19 @@ where
152149
}
153150
}
154151

155-
impl<F, const N: usize> MultiDistribution<F> for DirichletFromBeta<F, N>
152+
impl<F> MultiDistribution<F> for DirichletFromBeta<F>
156153
where
157154
F: Float,
158155
StandardNormal: Distribution<F>,
159156
Exp1: Distribution<F>,
160157
Open01: Distribution<F>,
161158
{
159+
#[inline]
162160
fn sample_len(&self) -> usize {
163-
N
161+
self.samplers.len() + 1
164162
}
165163
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
166-
assert_eq!(output.len(), N);
164+
assert_eq!(output.len(), self.sample_len());
167165

168166
let mut acc = F::one();
169167

@@ -172,24 +170,24 @@ where
172170
*s = acc * beta_sample;
173171
acc = acc * (F::one() - beta_sample);
174172
}
175-
output[N - 1] = acc;
173+
output[output.len() - 1] = acc;
176174
}
177175
}
178176

179177
#[derive(Clone, Debug, PartialEq)]
180178
#[cfg_attr(feature = "serde", serde_as)]
181-
enum DirichletRepr<F, const N: usize>
179+
enum DirichletRepr<F>
182180
where
183181
F: Float,
184182
StandardNormal: Distribution<F>,
185183
Exp1: Distribution<F>,
186184
Open01: Distribution<F>,
187185
{
188186
/// Dirichlet distribution that generates samples using the Gamma distribution.
189-
FromGamma(DirichletFromGamma<F, N>),
187+
FromGamma(DirichletFromGamma<F>),
190188

191189
/// Dirichlet distribution that generates samples using the Beta distribution.
192-
FromBeta(DirichletFromBeta<F, N>),
190+
FromBeta(DirichletFromBeta<F>),
193191
}
194192

195193
/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`.
@@ -217,20 +215,20 @@ where
217215
/// use rand_distr::multi::Dirichlet;
218216
/// use rand_distr::multi::MultiDistribution;
219217
///
220-
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
218+
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
221219
/// let samples = dirichlet.sample(&mut rand::rng());
222-
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
220+
/// println!("{:?} is from a Dirichlet(&[1.0, 2.0, 3.0]) distribution", samples);
223221
/// ```
224222
#[cfg_attr(feature = "serde", serde_as)]
225223
#[derive(Clone, Debug, PartialEq)]
226-
pub struct Dirichlet<F, const N: usize>
224+
pub struct Dirichlet<F>
227225
where
228226
F: Float,
229227
StandardNormal: Distribution<F>,
230228
Exp1: Distribution<F>,
231229
Open01: Distribution<F>,
232230
{
233-
repr: DirichletRepr<F, N>,
231+
repr: DirichletRepr<F>,
234232
}
235233

236234
/// Error type returned from [`Dirichlet::new`].
@@ -275,7 +273,7 @@ impl fmt::Display for Error {
275273
#[cfg(feature = "std")]
276274
impl std::error::Error for Error {}
277275

278-
impl<F, const N: usize> Dirichlet<F, N>
276+
impl<F> Dirichlet<F>
279277
where
280278
F: Float,
281279
StandardNormal: Distribution<F>,
@@ -287,8 +285,8 @@ where
287285
/// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive,
288286
/// finite and not subnormal.
289287
#[inline]
290-
pub fn new(alpha: [F; N]) -> Result<Dirichlet<F, N>, Error> {
291-
if N < 2 {
288+
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
289+
if alpha.len() < 2 {
292290
return Err(Error::AlphaTooShort);
293291
}
294292
for &ai in alpha.iter() {
@@ -322,15 +320,19 @@ where
322320
}
323321
}
324322

325-
impl<F, const N: usize> MultiDistribution<F> for Dirichlet<F, N>
323+
impl<F> MultiDistribution<F> for Dirichlet<F>
326324
where
327325
F: Float,
328326
StandardNormal: Distribution<F>,
329327
Exp1: Distribution<F>,
330328
Open01: Distribution<F>,
331329
{
330+
#[inline]
332331
fn sample_len(&self) -> usize {
333-
N
332+
match &self.repr {
333+
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_len(),
334+
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_len(),
335+
}
334336
}
335337
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
336338
match &self.repr {
@@ -340,7 +342,7 @@ where
340342
}
341343
}
342344

343-
impl<F, const N: usize> Distribution<Vec<F>> for Dirichlet<F, N>
345+
impl<F> Distribution<Vec<F>> for Dirichlet<F>
344346
where
345347
F: Float + Default,
346348
StandardNormal: Distribution<F>,
@@ -356,7 +358,7 @@ mod test {
356358

357359
#[test]
358360
fn test_dirichlet() {
359-
let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
361+
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
360362
let mut rng = crate::test::rng(221);
361363
let samples = d.sample(&mut rng);
362364
assert!(samples.into_iter().all(|x: f64| x > 0.0));
@@ -365,42 +367,42 @@ mod test {
365367
#[test]
366368
#[should_panic]
367369
fn test_dirichlet_invalid_length() {
368-
Dirichlet::new([0.5]).unwrap();
370+
Dirichlet::new(&[0.5]).unwrap();
369371
}
370372

371373
#[test]
372374
#[should_panic]
373375
fn test_dirichlet_alpha_zero() {
374-
Dirichlet::new([0.1, 0.0, 0.3]).unwrap();
376+
Dirichlet::new(&[0.1, 0.0, 0.3]).unwrap();
375377
}
376378

377379
#[test]
378380
#[should_panic]
379381
fn test_dirichlet_alpha_negative() {
380-
Dirichlet::new([0.1, -1.5, 0.3]).unwrap();
382+
Dirichlet::new(&[0.1, -1.5, 0.3]).unwrap();
381383
}
382384

383385
#[test]
384386
#[should_panic]
385387
fn test_dirichlet_alpha_nan() {
386-
Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap();
388+
Dirichlet::new(&[0.5, f64::NAN, 0.25]).unwrap();
387389
}
388390

389391
#[test]
390392
#[should_panic]
391393
fn test_dirichlet_alpha_subnormal() {
392-
Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap();
394+
Dirichlet::new(&[0.5, 1.5e-321, 0.25]).unwrap();
393395
}
394396

395397
#[test]
396398
#[should_panic]
397399
fn test_dirichlet_alpha_inf() {
398-
Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap();
400+
Dirichlet::new(&[0.5, f64::INFINITY, 0.25]).unwrap();
399401
}
400402

401403
#[test]
402404
fn dirichlet_distributions_can_be_compared() {
403-
assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0]));
405+
assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
404406
}
405407

406408
/// Check that the means of the components of n samples from
@@ -410,7 +412,7 @@ mod test {
410412
/// This is a crude statistical test, but it will catch egregious
411413
/// mistakes. It will also also fail if any samples contain nan.
412414
fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
413-
let d = Dirichlet::new(alpha).unwrap();
415+
let d = Dirichlet::new(&alpha).unwrap();
414416
let mut rng = crate::test::rng(seed);
415417
let mut sums = [0.0; N];
416418
for _ in 0..n {

tests/value_stability.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,13 +502,13 @@ fn weibull_stability() {
502502
fn dirichlet_stability() {
503503
let mut rng = get_rng(223);
504504
assert_eq!(
505-
multi::Dirichlet::new([1.0, 2.0, 3.0])
505+
multi::Dirichlet::new(&[1.0, 2.0, 3.0])
506506
.unwrap()
507507
.sample(&mut rng),
508508
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
509509
);
510510
assert_eq!(
511-
multi::Dirichlet::new([8.0; 5]).unwrap().sample(&mut rng),
511+
multi::Dirichlet::new(&[8.0; 5]).unwrap().sample(&mut rng),
512512
[
513513
0.17684200044809556,
514514
0.29915953935953055,
@@ -519,7 +519,7 @@ fn dirichlet_stability() {
519519
);
520520
// Test stability for the case where all alphas are less than 0.1.
521521
assert_eq!(
522-
multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05])
522+
multi::Dirichlet::new(&[0.05, 0.025, 0.075, 0.05])
523523
.unwrap()
524524
.sample(&mut rng),
525525
[

0 commit comments

Comments
 (0)