@@ -21,27 +21,24 @@ use alloc::{boxed::Box, vec, vec::Vec};
2121
2222#[ derive( Clone , Debug , PartialEq ) ]
2323#[ cfg_attr( feature = "serde" , serde_as) ]
24- struct DirichletFromGamma < F , const N : usize >
24+ struct DirichletFromGamma < F >
2525where
2626 F : Float ,
2727 StandardNormal : Distribution < F > ,
2828 Exp1 : Distribution < F > ,
2929 Open01 : Distribution < F > ,
3030{
31- samplers : [ Gamma < F > ; N ] ,
31+ samplers : Vec < Gamma < F > > ,
3232}
3333
3434/// Error type returned from [`DirchletFromGamma::new`].
3535#[ derive( Clone , Copy , Debug , PartialEq , Eq ) ]
3636enum DirichletFromGammaError {
3737 /// Gamma::new(a, 1) failed.
3838 GammmaNewFailed ,
39-
40- /// gamma_dists.try_into() failed (in theory, this should not happen).
41- GammaArrayCreationFailed ,
4239}
4340
44- impl < F , const N : usize > DirichletFromGamma < F , N >
41+ impl < F > DirichletFromGamma < F >
4542where
4643 F : Float ,
4744 StandardNormal : Distribution < F > ,
@@ -53,33 +50,32 @@ where
5350 /// This function is part of a private implementation detail.
5451 /// It assumes that the input is correct, so no validation of alpha is done.
5552 #[ inline]
56- fn new ( alpha : [ F ; N ] ) -> Result < DirichletFromGamma < F , N > , DirichletFromGammaError > {
53+ fn new ( alpha : & [ F ] ) -> Result < DirichletFromGamma < F > , DirichletFromGammaError > {
5754 let mut gamma_dists = Vec :: new ( ) ;
5855 for a in alpha {
5956 let dist =
60- Gamma :: new ( a, F :: one ( ) ) . map_err ( |_| DirichletFromGammaError :: GammmaNewFailed ) ?;
57+ Gamma :: new ( * a, F :: one ( ) ) . map_err ( |_| DirichletFromGammaError :: GammmaNewFailed ) ?;
6158 gamma_dists. push ( dist) ;
6259 }
6360 Ok ( DirichletFromGamma {
64- samplers : gamma_dists
65- . try_into ( )
66- . map_err ( |_| DirichletFromGammaError :: GammaArrayCreationFailed ) ?,
61+ samplers : gamma_dists,
6762 } )
6863 }
6964}
7065
71- impl < F , const N : usize > MultiDistribution < F > for DirichletFromGamma < F , N >
66+ impl < F > MultiDistribution < F > for DirichletFromGamma < F >
7267where
7368 F : Float ,
7469 StandardNormal : Distribution < F > ,
7570 Exp1 : Distribution < F > ,
7671 Open01 : Distribution < F > ,
7772{
73+ #[ inline]
7874 fn sample_len ( & self ) -> usize {
79- N
75+ self . samplers . len ( )
8076 }
8177 fn sample_to_slice < R : Rng + ?Sized > ( & self , rng : & mut R , output : & mut [ F ] ) {
82- assert_eq ! ( output. len( ) , N ) ;
78+ assert_eq ! ( output. len( ) , self . sample_len ( ) ) ;
8379
8480 let mut sum = F :: zero ( ) ;
8581
9692
9793#[ derive( Clone , Debug , PartialEq ) ]
9894#[ cfg_attr( feature = "serde" , derive( serde:: Serialize , serde:: Deserialize ) ) ]
99- struct DirichletFromBeta < F , const N : usize >
95+ struct DirichletFromBeta < F >
10096where
10197 F : Float ,
10298 StandardNormal : Distribution < F > ,
@@ -113,7 +109,7 @@ enum DirichletFromBetaError {
113109 BetaNewFailed ,
114110}
115111
116- impl < F , const N : usize > DirichletFromBeta < F , N >
112+ impl < F > DirichletFromBeta < F >
117113where
118114 F : Float ,
119115 StandardNormal : Distribution < F > ,
@@ -125,15 +121,16 @@ where
125121 /// This function is part of a private implementation detail.
126122 /// It assumes that the input is correct, so no validation of alpha is done.
127123 #[ inline]
128- fn new ( alpha : [ F ; N ] ) -> Result < DirichletFromBeta < F , N > , DirichletFromBetaError > {
124+ fn new ( alpha : & [ F ] ) -> Result < DirichletFromBeta < F > , DirichletFromBetaError > {
129125 // `alpha_rev_csum` is the reverse of the cumulative sum of the
130126 // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then
131127 // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`.
132128 // Note that instances of DirichletFromBeta will always have N >= 2,
133129 // so the subtractions of 1, 2 and 3 from N in the following are safe.
134- let mut alpha_rev_csum = vec ! [ alpha[ N - 1 ] ; N - 1 ] ;
135- for k in 0 ..( N - 2 ) {
136- alpha_rev_csum[ N - 3 - k] = alpha_rev_csum[ N - 2 - k] + alpha[ N - 2 - k] ;
130+ let n = alpha. len ( ) ;
131+ let mut alpha_rev_csum = vec ! [ alpha[ n - 1 ] ; n - 1 ] ;
132+ for k in 0 ..( n - 2 ) {
133+ alpha_rev_csum[ n - 3 - k] = alpha_rev_csum[ n - 2 - k] + alpha[ n - 2 - k] ;
137134 }
138135
139136 // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example
@@ -142,7 +139,7 @@ where
142139 // Then pass each tuple to `Beta::new()` to create the `Beta`
143140 // instances.
144141 let mut beta_dists = Vec :: new ( ) ;
145- for ( & a, & b) in alpha[ ..( N - 1 ) ] . iter ( ) . zip ( alpha_rev_csum. iter ( ) ) {
142+ for ( & a, & b) in alpha[ ..( n - 1 ) ] . iter ( ) . zip ( alpha_rev_csum. iter ( ) ) {
146143 let dist = Beta :: new ( a, b) . map_err ( |_| DirichletFromBetaError :: BetaNewFailed ) ?;
147144 beta_dists. push ( dist) ;
148145 }
@@ -152,18 +149,19 @@ where
152149 }
153150}
154151
155- impl < F , const N : usize > MultiDistribution < F > for DirichletFromBeta < F , N >
152+ impl < F > MultiDistribution < F > for DirichletFromBeta < F >
156153where
157154 F : Float ,
158155 StandardNormal : Distribution < F > ,
159156 Exp1 : Distribution < F > ,
160157 Open01 : Distribution < F > ,
161158{
159+ #[ inline]
162160 fn sample_len ( & self ) -> usize {
163- N
161+ self . samplers . len ( ) + 1
164162 }
165163 fn sample_to_slice < R : Rng + ?Sized > ( & self , rng : & mut R , output : & mut [ F ] ) {
166- assert_eq ! ( output. len( ) , N ) ;
164+ assert_eq ! ( output. len( ) , self . sample_len ( ) ) ;
167165
168166 let mut acc = F :: one ( ) ;
169167
@@ -172,24 +170,24 @@ where
172170 * s = acc * beta_sample;
173171 acc = acc * ( F :: one ( ) - beta_sample) ;
174172 }
175- output[ N - 1 ] = acc;
173+ output[ output . len ( ) - 1 ] = acc;
176174 }
177175}
178176
179177#[ derive( Clone , Debug , PartialEq ) ]
180178#[ cfg_attr( feature = "serde" , serde_as) ]
181- enum DirichletRepr < F , const N : usize >
179+ enum DirichletRepr < F >
182180where
183181 F : Float ,
184182 StandardNormal : Distribution < F > ,
185183 Exp1 : Distribution < F > ,
186184 Open01 : Distribution < F > ,
187185{
188186 /// Dirichlet distribution that generates samples using the Gamma distribution.
189- FromGamma ( DirichletFromGamma < F , N > ) ,
187+ FromGamma ( DirichletFromGamma < F > ) ,
190188
191189 /// Dirichlet distribution that generates samples using the Beta distribution.
192- FromBeta ( DirichletFromBeta < F , N > ) ,
190+ FromBeta ( DirichletFromBeta < F > ) ,
193191}
194192
195193/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`.
@@ -217,20 +215,20 @@ where
217215/// use rand_distr::multi::Dirichlet;
218216/// use rand_distr::multi::MultiDistribution;
219217///
220- /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
218+ /// let dirichlet = Dirichlet::new(& [1.0, 2.0, 3.0]).unwrap();
221219/// let samples = dirichlet.sample(&mut rand::rng());
222- /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
220+ /// println!("{:?} is from a Dirichlet(& [1.0, 2.0, 3.0]) distribution", samples);
223221/// ```
224222#[ cfg_attr( feature = "serde" , serde_as) ]
225223#[ derive( Clone , Debug , PartialEq ) ]
226- pub struct Dirichlet < F , const N : usize >
224+ pub struct Dirichlet < F >
227225where
228226 F : Float ,
229227 StandardNormal : Distribution < F > ,
230228 Exp1 : Distribution < F > ,
231229 Open01 : Distribution < F > ,
232230{
233- repr : DirichletRepr < F , N > ,
231+ repr : DirichletRepr < F > ,
234232}
235233
236234/// Error type returned from [`Dirichlet::new`].
@@ -275,7 +273,7 @@ impl fmt::Display for Error {
275273#[ cfg( feature = "std" ) ]
276274impl std:: error:: Error for Error { }
277275
278- impl < F , const N : usize > Dirichlet < F , N >
276+ impl < F > Dirichlet < F >
279277where
280278 F : Float ,
281279 StandardNormal : Distribution < F > ,
@@ -287,8 +285,8 @@ where
287285 /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive,
288286 /// finite and not subnormal.
289287 #[ inline]
290- pub fn new ( alpha : [ F ; N ] ) -> Result < Dirichlet < F , N > , Error > {
291- if N < 2 {
288+ pub fn new ( alpha : & [ F ] ) -> Result < Dirichlet < F > , Error > {
289+ if alpha . len ( ) < 2 {
292290 return Err ( Error :: AlphaTooShort ) ;
293291 }
294292 for & ai in alpha. iter ( ) {
@@ -322,15 +320,19 @@ where
322320 }
323321}
324322
325- impl < F , const N : usize > MultiDistribution < F > for Dirichlet < F , N >
323+ impl < F > MultiDistribution < F > for Dirichlet < F >
326324where
327325 F : Float ,
328326 StandardNormal : Distribution < F > ,
329327 Exp1 : Distribution < F > ,
330328 Open01 : Distribution < F > ,
331329{
330+ #[ inline]
332331 fn sample_len ( & self ) -> usize {
333- N
332+ match & self . repr {
333+ DirichletRepr :: FromGamma ( dirichlet) => dirichlet. sample_len ( ) ,
334+ DirichletRepr :: FromBeta ( dirichlet) => dirichlet. sample_len ( ) ,
335+ }
334336 }
335337 fn sample_to_slice < R : Rng + ?Sized > ( & self , rng : & mut R , output : & mut [ F ] ) {
336338 match & self . repr {
@@ -340,7 +342,7 @@ where
340342 }
341343}
342344
343- impl < F , const N : usize > Distribution < Vec < F > > for Dirichlet < F , N >
345+ impl < F > Distribution < Vec < F > > for Dirichlet < F >
344346where
345347 F : Float + Default ,
346348 StandardNormal : Distribution < F > ,
@@ -356,7 +358,7 @@ mod test {
356358
357359 #[ test]
358360 fn test_dirichlet ( ) {
359- let d = Dirichlet :: new ( [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
361+ let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
360362 let mut rng = crate :: test:: rng ( 221 ) ;
361363 let samples = d. sample ( & mut rng) ;
362364 assert ! ( samples. into_iter( ) . all( |x: f64 | x > 0.0 ) ) ;
@@ -365,42 +367,42 @@ mod test {
365367 #[ test]
366368 #[ should_panic]
367369 fn test_dirichlet_invalid_length ( ) {
368- Dirichlet :: new ( [ 0.5 ] ) . unwrap ( ) ;
370+ Dirichlet :: new ( & [ 0.5 ] ) . unwrap ( ) ;
369371 }
370372
371373 #[ test]
372374 #[ should_panic]
373375 fn test_dirichlet_alpha_zero ( ) {
374- Dirichlet :: new ( [ 0.1 , 0.0 , 0.3 ] ) . unwrap ( ) ;
376+ Dirichlet :: new ( & [ 0.1 , 0.0 , 0.3 ] ) . unwrap ( ) ;
375377 }
376378
377379 #[ test]
378380 #[ should_panic]
379381 fn test_dirichlet_alpha_negative ( ) {
380- Dirichlet :: new ( [ 0.1 , -1.5 , 0.3 ] ) . unwrap ( ) ;
382+ Dirichlet :: new ( & [ 0.1 , -1.5 , 0.3 ] ) . unwrap ( ) ;
381383 }
382384
383385 #[ test]
384386 #[ should_panic]
385387 fn test_dirichlet_alpha_nan ( ) {
386- Dirichlet :: new ( [ 0.5 , f64:: NAN , 0.25 ] ) . unwrap ( ) ;
388+ Dirichlet :: new ( & [ 0.5 , f64:: NAN , 0.25 ] ) . unwrap ( ) ;
387389 }
388390
389391 #[ test]
390392 #[ should_panic]
391393 fn test_dirichlet_alpha_subnormal ( ) {
392- Dirichlet :: new ( [ 0.5 , 1.5e-321 , 0.25 ] ) . unwrap ( ) ;
394+ Dirichlet :: new ( & [ 0.5 , 1.5e-321 , 0.25 ] ) . unwrap ( ) ;
393395 }
394396
395397 #[ test]
396398 #[ should_panic]
397399 fn test_dirichlet_alpha_inf ( ) {
398- Dirichlet :: new ( [ 0.5 , f64:: INFINITY , 0.25 ] ) . unwrap ( ) ;
400+ Dirichlet :: new ( & [ 0.5 , f64:: INFINITY , 0.25 ] ) . unwrap ( ) ;
399401 }
400402
401403 #[ test]
402404 fn dirichlet_distributions_can_be_compared ( ) {
403- assert_eq ! ( Dirichlet :: new( [ 1.0 , 2.0 ] ) , Dirichlet :: new( [ 1.0 , 2.0 ] ) ) ;
405+ assert_eq ! ( Dirichlet :: new( & [ 1.0 , 2.0 ] ) , Dirichlet :: new( & [ 1.0 , 2.0 ] ) ) ;
404406 }
405407
406408 /// Check that the means of the components of n samples from
@@ -410,7 +412,7 @@ mod test {
410412 /// This is a crude statistical test, but it will catch egregious
411413 /// mistakes. It will also also fail if any samples contain nan.
412414 fn check_dirichlet_means < const N : usize > ( alpha : [ f64 ; N ] , n : i32 , rtol : f64 , seed : u64 ) {
413- let d = Dirichlet :: new ( alpha) . unwrap ( ) ;
415+ let d = Dirichlet :: new ( & alpha) . unwrap ( ) ;
414416 let mut rng = crate :: test:: rng ( seed) ;
415417 let mut sums = [ 0.0 ; N ] ;
416418 for _ in 0 ..n {
0 commit comments