|
10 | 10 | //! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. |
11 | 11 |
|
12 | 12 | #![cfg(feature = "alloc")] |
13 | | -use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; |
| 13 | +use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal, multi::MultiDistribution}; |
14 | 14 | use core::fmt; |
15 | 15 | use num_traits::{Float, NumCast}; |
16 | 16 | use rand::Rng; |
@@ -68,26 +68,29 @@ where |
68 | 68 | } |
69 | 69 | } |
70 | 70 |
|
71 | | -impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N> |
| 71 | +impl<F, const N: usize> MultiDistribution<F> for DirichletFromGamma<F, N> |
72 | 72 | where |
73 | 73 | F: Float, |
74 | 74 | StandardNormal: Distribution<F>, |
75 | 75 | Exp1: Distribution<F>, |
76 | 76 | Open01: Distribution<F>, |
77 | 77 | { |
78 | | - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { |
79 | | - let mut samples = [F::zero(); N]; |
| 78 | + fn sample_len(&self) -> usize { |
| 79 | + N |
| 80 | + } |
| 81 | + fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { |
| 82 | + assert_eq!(output.len(), N); |
| 83 | + |
80 | 84 | let mut sum = F::zero(); |
81 | 85 |
|
82 | | - for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { |
| 86 | + for (s, g) in output.iter_mut().zip(self.samplers.iter()) { |
83 | 87 | *s = g.sample(rng); |
84 | 88 | sum = sum + *s; |
85 | 89 | } |
86 | 90 | let invacc = F::one() / sum; |
87 | | - for s in samples.iter_mut() { |
| 91 | + for s in output.iter_mut() { |
88 | 92 | *s = *s * invacc; |
89 | 93 | } |
90 | | - samples |
91 | 94 | } |
92 | 95 | } |
93 | 96 |
|
@@ -149,24 +152,27 @@ where |
149 | 152 | } |
150 | 153 | } |
151 | 154 |
|
152 | | -impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N> |
| 155 | +impl<F, const N: usize> MultiDistribution<F> for DirichletFromBeta<F, N> |
153 | 156 | where |
154 | 157 | F: Float, |
155 | 158 | StandardNormal: Distribution<F>, |
156 | 159 | Exp1: Distribution<F>, |
157 | 160 | Open01: Distribution<F>, |
158 | 161 | { |
159 | | - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { |
160 | | - let mut samples = [F::zero(); N]; |
| 162 | + fn sample_len(&self) -> usize { |
| 163 | + N |
| 164 | + } |
| 165 | + fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { |
| 166 | + assert_eq!(output.len(), N); |
| 167 | + |
161 | 168 | let mut acc = F::one(); |
162 | 169 |
|
163 | | - for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { |
| 170 | + for (s, beta) in output.iter_mut().zip(self.samplers.iter()) { |
164 | 171 | let beta_sample = beta.sample(rng); |
165 | 172 | *s = acc * beta_sample; |
166 | 173 | acc = acc * (F::one() - beta_sample); |
167 | 174 | } |
168 | | - samples[N - 1] = acc; |
169 | | - samples |
| 175 | + output[N - 1] = acc; |
170 | 176 | } |
171 | 177 | } |
172 | 178 |
|
@@ -208,7 +214,8 @@ where |
208 | 214 | /// |
209 | 215 | /// ``` |
210 | 216 | /// use rand::prelude::*; |
211 | | -/// use rand_distr::Dirichlet; |
| 217 | +/// use rand_distr::multi::Dirichlet; |
| 218 | +/// use rand_distr::multi::MultiDistribution; |
212 | 219 | /// |
213 | 220 | /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); |
214 | 221 | /// let samples = dirichlet.sample(&mut rand::rng()); |
@@ -259,7 +266,7 @@ impl fmt::Display for Error { |
259 | 266 | "failed to create required Gamma distribution for Dirichlet distribution" |
260 | 267 | } |
261 | 268 | Error::FailedToCreateBeta => { |
262 | | - "failed to create required Beta distribition for Dirichlet distribution" |
| 269 | + "failed to create required Beta distribution for Dirichlet distribution" |
263 | 270 | } |
264 | 271 | }) |
265 | 272 | } |
@@ -315,21 +322,34 @@ where |
315 | 322 | } |
316 | 323 | } |
317 | 324 |
|
318 | | -impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N> |
| 325 | +impl<F, const N: usize> MultiDistribution<F> for Dirichlet<F, N> |
319 | 326 | where |
320 | 327 | F: Float, |
321 | 328 | StandardNormal: Distribution<F>, |
322 | 329 | Exp1: Distribution<F>, |
323 | 330 | Open01: Distribution<F>, |
324 | 331 | { |
325 | | - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] { |
| 332 | + fn sample_len(&self) -> usize { |
| 333 | + N |
| 334 | + } |
| 335 | + fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) { |
326 | 336 | match &self.repr { |
327 | | - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), |
328 | | - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), |
| 337 | + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_slice(rng, output), |
| 338 | + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_slice(rng, output), |
329 | 339 | } |
330 | 340 | } |
331 | 341 | } |
332 | 342 |
|
| 343 | +impl<F, const N: usize> Distribution<Vec<F>> for Dirichlet<F, N> |
| 344 | +where |
| 345 | + F: Float + Default, |
| 346 | + StandardNormal: Distribution<F>, |
| 347 | + Exp1: Distribution<F>, |
| 348 | + Open01: Distribution<F>, |
| 349 | +{ |
| 350 | + distribution_impl!(F); |
| 351 | +} |
| 352 | + |
333 | 353 | #[cfg(test)] |
334 | 354 | mod test { |
335 | 355 | use super::*; |
@@ -403,7 +423,7 @@ mod test { |
403 | 423 | let alpha_sum: f64 = alpha.iter().sum(); |
404 | 424 | let expected_mean = alpha.map(|x| x / alpha_sum); |
405 | 425 | for i in 0..N { |
406 | | - assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); |
| 426 | + average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); |
407 | 427 | } |
408 | 428 | } |
409 | 429 |
|
|
0 commit comments