@@ -79,6 +79,7 @@ pub struct WeightedAliasIndex<W: AliasableWeight> {
7979 no_alias_odds : Box < [ W ] > ,
8080 uniform_index : Uniform < u32 > ,
8181 uniform_within_weight_sum : Uniform < W > ,
82+ weight_sum : W ,
8283}
8384
8485impl < W : AliasableWeight > WeightedAliasIndex < W > {
@@ -231,8 +232,42 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
231232 no_alias_odds,
232233 uniform_index,
233234 uniform_within_weight_sum,
235+ weight_sum,
234236 } )
235237 }
238+
239+ /// Reconstructs and returns the original weights used to create the distribution.
240+ ///
241+ /// `O(n)` time, where `n` is the number of weights.
242+ ///
243+ /// Note: Exact values may not be recovered if `W` is a float.
244+ pub fn weights ( & self ) -> Vec < W > {
245+ let n = self . aliases . len ( ) ;
246+
247+ // `n` was validated in the constructor.
248+ let n_converted = W :: try_from_u32_lossy ( n as u32 ) . unwrap ( ) ;
249+
250+ // pre-calculate the total contribution each index receives from serving
251+ // as an alias for other indices.
252+ let mut alias_contributions = vec ! [ W :: ZERO ; n] ;
253+ for j in 0 ..n {
254+ if self . no_alias_odds [ j] < self . weight_sum {
255+ let contribution = self . weight_sum - self . no_alias_odds [ j] ;
256+ let alias_index = self . aliases [ j] as usize ;
257+ alias_contributions[ alias_index] += contribution;
258+ }
259+ }
260+
261+ // Reconstruct each weight by combining its direct `no_alias_odds`
262+ // with its total `alias_contributions` and scaling the result.
263+ self . no_alias_odds
264+ . iter ( )
265+ . zip ( & alias_contributions)
266+ . map ( |( & no_alias_odd, & alias_contribution) | {
267+ ( no_alias_odd + alias_contribution) / n_converted
268+ } )
269+ . collect ( )
270+ }
236271}
237272
238273impl < W : AliasableWeight > Distribution < usize > for WeightedAliasIndex < W > {
@@ -271,6 +306,7 @@ where
271306 no_alias_odds : self . no_alias_odds . clone ( ) ,
272307 uniform_index : self . uniform_index ,
273308 uniform_within_weight_sum : self . uniform_within_weight_sum . clone ( ) ,
309+ weight_sum : self . weight_sum ,
274310 }
275311 }
276312}
@@ -503,6 +539,48 @@ mod test {
503539 ) ;
504540 }
505541
542+ #[ test]
543+ fn test_weights_reconstruction ( ) {
544+ // Standard integers
545+ {
546+ let weights_i32 = vec ! [ 10 , 2 , 8 , 0 , 30 , 5 ] ;
547+ let dist_i32 = WeightedAliasIndex :: new ( weights_i32. clone ( ) ) . unwrap ( ) ;
548+ assert_eq ! ( weights_i32, dist_i32. weights( ) ) ;
549+ }
550+
551+ // Uniform weights
552+ {
553+ let weights_u64 = vec ! [ 1 , 1 , 1 , 1 , 1 ] ;
554+ let dist_u64 = WeightedAliasIndex :: new ( weights_u64. clone ( ) ) . unwrap ( ) ;
555+ assert_eq ! ( weights_u64, dist_u64. weights( ) ) ;
556+ }
557+
558+ // Floating point
559+ {
560+ const EPSILON : f64 = 1e-9 ;
561+ let weights_f64 = vec ! [ 0.5 , 0.2 , 0.3 , 0.0 , 1.5 , 0.88 ] ;
562+ let dist_f64 = WeightedAliasIndex :: new ( weights_f64. clone ( ) ) . unwrap ( ) ;
563+ let reconstructed_f64 = dist_f64. weights ( ) ;
564+
565+ assert_eq ! ( weights_f64. len( ) , reconstructed_f64. len( ) ) ;
566+ for ( original, reconstructed) in weights_f64. iter ( ) . zip ( reconstructed_f64. iter ( ) ) {
567+ assert ! (
568+ f64 :: abs( original - reconstructed) < EPSILON ,
569+ "Weight reconstruction failed: original {}, reconstructed {}" ,
570+ original,
571+ reconstructed
572+ ) ;
573+ }
574+ }
575+
576+ // Single item
577+ {
578+ let weights_single = vec ! [ 42_u32 ] ;
579+ let dist_single = WeightedAliasIndex :: new ( weights_single. clone ( ) ) . unwrap ( ) ;
580+ assert_eq ! ( weights_single, dist_single. weights( ) ) ;
581+ }
582+ }
583+
506584 #[ test]
507585 fn value_stability ( ) {
508586 fn test_samples < W : AliasableWeight > (
0 commit comments