@@ -84,6 +84,7 @@ use core::fmt;
8484#[ derive( Debug , Clone ) ]
8585pub struct WeightedIndex < X : SampleUniform + PartialOrd > {
8686 cumulative_weights : Vec < X > ,
87+ total_weight : X ,
8788 weight_distribution : X :: Sampler ,
8889}
8990
@@ -125,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
125126 if total_weight == zero {
126127 return Err ( WeightedError :: AllWeightsZero ) ;
127128 }
128- let distr = X :: Sampler :: new ( zero, total_weight) ;
129+ let distr = X :: Sampler :: new ( zero, total_weight. clone ( ) ) ;
129130
130- Ok ( WeightedIndex { cumulative_weights : weights, weight_distribution : distr } )
131+ Ok ( WeightedIndex { cumulative_weights : weights, total_weight, weight_distribution : distr } )
132+ }
133+
134+ /// Update a subset of weights, without changing the number of weights.
135+ ///
136+ /// `new_weights` must be sorted by the index.
137+ ///
138+ /// Using this method instead of `new` might be more efficient if only a small number of
139+ /// weights is modified. No allocations are performed, unless the weight type `X` uses
140+ /// allocation internally.
141+ ///
142+ /// In case of error, `self` is not modified.
143+ pub fn update_weights ( & mut self , new_weights : & [ ( usize , & X ) ] ) -> Result < ( ) , WeightedError >
144+ where X : for < ' a > :: core:: ops:: AddAssign < & ' a X > +
145+ for < ' a > :: core:: ops:: SubAssign < & ' a X > +
146+ Clone +
147+ Default {
148+ if new_weights. is_empty ( ) {
149+ return Ok ( ( ) ) ;
150+ }
151+
152+ let zero = <X as Default >:: default ( ) ;
153+
154+ let mut total_weight = self . total_weight . clone ( ) ;
155+
156+ // Check for errors first, so we don't modify `self` in case something
157+ // goes wrong.
158+ let mut prev_i = None ;
159+ for & ( i, w) in new_weights {
160+ if let Some ( old_i) = prev_i {
161+ if old_i >= i {
162+ return Err ( WeightedError :: InvalidWeight ) ;
163+ }
164+ }
165+ if * w < zero {
166+ return Err ( WeightedError :: InvalidWeight ) ;
167+ }
168+ if i >= self . cumulative_weights . len ( ) + 1 {
169+ return Err ( WeightedError :: TooMany ) ;
170+ }
171+
172+ let mut old_w = if i < self . cumulative_weights . len ( ) {
173+ self . cumulative_weights [ i] . clone ( )
174+ } else {
175+ self . total_weight . clone ( )
176+ } ;
177+ if i > 0 {
178+ old_w -= & self . cumulative_weights [ i - 1 ] ;
179+ }
180+
181+ total_weight -= & old_w;
182+ total_weight += w;
183+ prev_i = Some ( i) ;
184+ }
185+ if total_weight == zero {
186+ return Err ( WeightedError :: AllWeightsZero ) ;
187+ }
188+
189+ // Update the weights. Because we checked all the preconditions in the
190+ // previous loop, this should never panic.
191+ let mut iter = new_weights. iter ( ) ;
192+
193+ let mut prev_weight = zero. clone ( ) ;
194+ let mut next_new_weight = iter. next ( ) ;
195+ let & ( first_new_index, _) = next_new_weight. unwrap ( ) ;
196+ let mut cumulative_weight = if first_new_index > 0 {
197+ self . cumulative_weights [ first_new_index - 1 ] . clone ( )
198+ } else {
199+ zero. clone ( )
200+ } ;
201+ for i in first_new_index..self . cumulative_weights . len ( ) {
202+ match next_new_weight {
203+ Some ( & ( j, w) ) if i == j => {
204+ cumulative_weight += w;
205+ next_new_weight = iter. next ( ) ;
206+ } ,
207+ _ => {
208+ let mut tmp = self . cumulative_weights [ i] . clone ( ) ;
209+ tmp -= & prev_weight; // We know this is positive.
210+ cumulative_weight += & tmp;
211+ }
212+ }
213+ prev_weight = cumulative_weight. clone ( ) ;
214+ core:: mem:: swap ( & mut prev_weight, & mut self . cumulative_weights [ i] ) ;
215+ }
216+
217+ self . total_weight = total_weight;
218+ self . weight_distribution = X :: Sampler :: new ( zero, self . total_weight . clone ( ) ) ;
219+
220+ Ok ( ( ) )
131221 }
132222}
133223
@@ -201,6 +291,31 @@ mod test {
201291 assert_eq ! ( WeightedIndex :: new( & [ -10 , 20 , 1 , 30 ] ) . unwrap_err( ) , WeightedError :: InvalidWeight ) ;
202292 assert_eq ! ( WeightedIndex :: new( & [ -10 ] ) . unwrap_err( ) , WeightedError :: InvalidWeight ) ;
203293 }
294+
295+ #[ test]
296+ fn test_update_weights ( ) {
297+ let data = [
298+ ( & [ 10u32 , 2 , 3 , 4 ] [ ..] ,
299+ & [ ( 1 , & 100 ) , ( 2 , & 4 ) ] [ ..] , // positive change
300+ & [ 10 , 100 , 4 , 4 ] [ ..] ) ,
301+ ( & [ 1u32 , 2 , 3 , 0 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] [ ..] ,
302+ & [ ( 2 , & 1 ) , ( 5 , & 1 ) , ( 13 , & 100 ) ] [ ..] , // negative change and last element
303+ & [ 1u32 , 2 , 1 , 0 , 5 , 1 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 100 ] [ ..] ) ,
304+ ] ;
305+
306+ for ( weights, update, expected_weights) in data. into_iter ( ) {
307+ let total_weight = weights. iter ( ) . sum :: < u32 > ( ) ;
308+ let mut distr = WeightedIndex :: new ( weights. to_vec ( ) ) . unwrap ( ) ;
309+ assert_eq ! ( distr. total_weight, total_weight) ;
310+
311+ distr. update_weights ( update) . unwrap ( ) ;
312+ let expected_total_weight = expected_weights. iter ( ) . sum :: < u32 > ( ) ;
313+ let expected_distr = WeightedIndex :: new ( expected_weights. to_vec ( ) ) . unwrap ( ) ;
314+ assert_eq ! ( distr. total_weight, expected_total_weight) ;
315+ assert_eq ! ( distr. total_weight, expected_distr. total_weight) ;
316+ assert_eq ! ( distr. cumulative_weights, expected_distr. cumulative_weights) ;
317+ }
318+ }
204319}
205320
206321/// Error type returned from `WeightedIndex::new`.
0 commit comments