146146use core:: fmt;
147147use core:: iter:: { FromIterator , FusedIterator , InPlaceIterable , SourceIter , TrustedLen } ;
148148use core:: mem:: { self , swap, ManuallyDrop } ;
149+ use core:: num:: NonZeroUsize ;
149150use core:: ops:: { Deref , DerefMut } ;
150151use core:: ptr;
151152
@@ -279,7 +280,9 @@ pub struct BinaryHeap<T> {
279280#[ stable( feature = "binary_heap_peek_mut" , since = "1.12.0" ) ]
280281pub struct PeekMut < ' a , T : ' a + Ord > {
281282 heap : & ' a mut BinaryHeap < T > ,
282- sift : bool ,
283+ // If a set_len + sift_down are required, this is Some. If a &mut T has not
284+ // yet been exposed to peek_mut()'s caller, it's None.
285+ original_len : Option < NonZeroUsize > ,
283286}
284287
285288#[ stable( feature = "collection_debug" , since = "1.17.0" ) ]
@@ -292,7 +295,14 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
292295#[ stable( feature = "binary_heap_peek_mut" , since = "1.12.0" ) ]
293296impl < T : Ord > Drop for PeekMut < ' _ , T > {
294297 fn drop ( & mut self ) {
295- if self . sift {
298+ if let Some ( original_len) = self . original_len {
299+ // SAFETY: That's how many elements were in the Vec at the time of
300+ // the PeekMut::deref_mut call, and therefore also at the time of
301+ // the BinaryHeap::peek_mut call. Since the PeekMut did not end up
302+ // getting leaked, we are now undoing the leak amplification that
303+ // the DerefMut prepared for.
304+ unsafe { self . heap . data . set_len ( original_len. get ( ) ) } ;
305+
296306 // SAFETY: PeekMut is only instantiated for non-empty heaps.
297307 unsafe { self . heap . sift_down ( 0 ) } ;
298308 }
@@ -313,7 +323,26 @@ impl<T: Ord> Deref for PeekMut<'_, T> {
313323impl < T : Ord > DerefMut for PeekMut < ' _ , T > {
314324 fn deref_mut ( & mut self ) -> & mut T {
315325 debug_assert ! ( !self . heap. is_empty( ) ) ;
316- self . sift = true ;
326+
327+ let len = self . heap . len ( ) ;
328+ if len > 1 {
329+ // Here we preemptively leak all the rest of the underlying vector
330+ // after the currently max element. If the caller mutates the &mut T
331+ // we're about to give them, and then leaks the PeekMut, all these
332+ // elements will remain leaked. If they don't leak the PeekMut, then
333+ // either Drop or PeekMut::pop will un-leak the vector elements.
334+ //
335+ // This is technique is described throughout several other places in
336+ // the standard library as "leak amplification".
337+ unsafe {
338+ // SAFETY: len > 1 so len != 0.
339+ self . original_len = Some ( NonZeroUsize :: new_unchecked ( len) ) ;
340+ // SAFETY: len > 1 so all this does for now is leak elements,
341+ // which is safe.
342+ self . heap . data . set_len ( 1 ) ;
343+ }
344+ }
345+
317346 // SAFE: PeekMut is only instantiated for non-empty heaps
318347 unsafe { self . heap . data . get_unchecked_mut ( 0 ) }
319348 }
@@ -323,9 +352,16 @@ impl<'a, T: Ord> PeekMut<'a, T> {
323352 /// Removes the peeked value from the heap and returns it.
324353 #[ stable( feature = "binary_heap_peek_mut_pop" , since = "1.18.0" ) ]
325354 pub fn pop ( mut this : PeekMut < ' a , T > ) -> T {
326- let value = this. heap . pop ( ) . unwrap ( ) ;
327- this. sift = false ;
328- value
355+ if let Some ( original_len) = this. original_len . take ( ) {
356+ // SAFETY: This is how many elements were in the Vec at the time of
357+ // the BinaryHeap::peek_mut call.
358+ unsafe { this. heap . data . set_len ( original_len. get ( ) ) } ;
359+
360+ // Unlike in Drop, here we don't also need to do a sift_down even if
361+ // the caller could've mutated the element. It is removed from the
362+ // heap on the next line and pop() is not sensitive to its value.
363+ }
364+ this. heap . pop ( ) . unwrap ( )
329365 }
330366}
331367
@@ -398,8 +434,9 @@ impl<T: Ord> BinaryHeap<T> {
398434 /// Returns a mutable reference to the greatest item in the binary heap, or
399435 /// `None` if it is empty.
400436 ///
401- /// Note: If the `PeekMut` value is leaked, the heap may be in an
402- /// inconsistent state.
437+ /// Note: If the `PeekMut` value is leaked, some heap elements might get
438+ /// leaked along with it, but the remaining elements will remain a valid
439+ /// heap.
403440 ///
404441 /// # Examples
405442 ///
@@ -426,7 +463,7 @@ impl<T: Ord> BinaryHeap<T> {
426463 /// otherwise it's *O*(1).
427464 #[ stable( feature = "binary_heap_peek_mut" , since = "1.12.0" ) ]
428465 pub fn peek_mut ( & mut self ) -> Option < PeekMut < ' _ , T > > {
429- if self . is_empty ( ) { None } else { Some ( PeekMut { heap : self , sift : false } ) }
466+ if self . is_empty ( ) { None } else { Some ( PeekMut { heap : self , original_len : None } ) }
430467 }
431468
432469 /// Removes the greatest item from the binary heap and returns it, or `None` if it
0 commit comments