@@ -247,6 +247,7 @@ use super::SpecExtend;
247247/// [peek]: BinaryHeap::peek
248248/// [peek\_mut]: BinaryHeap::peek_mut
249249#[ stable( feature = "rust1" , since = "1.0.0" ) ]
250+ #[ cfg_attr( not( test) , rustc_diagnostic_item = "BinaryHeap" ) ]
250251pub struct BinaryHeap < T > {
251252 data : Vec < T > ,
252253}
@@ -275,7 +276,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275276impl < T : Ord > Drop for PeekMut < ' _ , T > {
276277 fn drop ( & mut self ) {
277278 if self . sift {
278- self . heap . sift_down ( 0 ) ;
279+ // SAFETY: PeekMut is only instantiated for non-empty heaps.
280+ unsafe { self . heap . sift_down ( 0 ) } ;
279281 }
280282 }
281283}
@@ -431,7 +433,8 @@ impl<T: Ord> BinaryHeap<T> {
431433 self . data . pop ( ) . map ( |mut item| {
432434 if !self . is_empty ( ) {
433435 swap ( & mut item, & mut self . data [ 0 ] ) ;
434- self . sift_down_to_bottom ( 0 ) ;
436+ // SAFETY: !self.is_empty() means that self.len() > 0
437+ unsafe { self . sift_down_to_bottom ( 0 ) } ;
435438 }
436439 item
437440 } )
@@ -473,7 +476,9 @@ impl<T: Ord> BinaryHeap<T> {
473476 pub fn push ( & mut self , item : T ) {
474477 let old_len = self . len ( ) ;
475478 self . data . push ( item) ;
476- self . sift_up ( 0 , old_len) ;
479+ // SAFETY: Since we pushed a new item it means that
480+ // old_len = self.len() - 1 < self.len()
481+ unsafe { self . sift_up ( 0 , old_len) } ;
477482 }
478483
479484 /// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +511,10 @@ impl<T: Ord> BinaryHeap<T> {
506511 let ptr = self . data . as_mut_ptr ( ) ;
507512 ptr:: swap ( ptr, ptr. add ( end) ) ;
508513 }
509- self . sift_down_range ( 0 , end) ;
514+ // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
515+ // 0 < 1 <= end <= self.len() - 1 < self.len()
516+ // Which means 0 < end and end < self.len().
517+ unsafe { self . sift_down_range ( 0 , end) } ;
510518 }
511519 self . into_vec ( )
512520 }
@@ -519,78 +527,139 @@ impl<T: Ord> BinaryHeap<T> {
519527 // the hole is filled back at the end of its scope, even on panic.
520528 // Using a hole reduces the constant factor compared to using swaps,
521529 // which involves twice as many moves.
522- fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
523- unsafe {
524- // Take out the value at `pos` and create a hole.
525- let mut hole = Hole :: new ( & mut self . data , pos) ;
526-
527- while hole. pos ( ) > start {
528- let parent = ( hole. pos ( ) - 1 ) / 2 ;
529- if hole. element ( ) <= hole. get ( parent) {
530- break ;
531- }
532- hole. move_to ( parent) ;
530+
531+ /// # Safety
532+ ///
533+ /// The caller must guarantee that `pos < self.len()`.
534+ unsafe fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
535+ // Take out the value at `pos` and create a hole.
536+ // SAFETY: The caller guarantees that pos < self.len()
537+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
538+
539+ while hole. pos ( ) > start {
540+ let parent = ( hole. pos ( ) - 1 ) / 2 ;
541+
542+ // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
543+ // and so hole.pos() - 1 can't underflow.
544+ // This guarantees that parent < hole.pos() so
545+ // it's a valid index and also != hole.pos().
546+ if hole. element ( ) <= unsafe { hole. get ( parent) } {
547+ break ;
533548 }
534- hole. pos ( )
549+
550+ // SAFETY: Same as above
551+ unsafe { hole. move_to ( parent) } ;
535552 }
553+
554+ hole. pos ( )
536555 }
537556
538557 /// Take an element at `pos` and move it down the heap,
539558 /// while its children are larger.
540- fn sift_down_range ( & mut self , pos : usize , end : usize ) {
541- unsafe {
542- let mut hole = Hole :: new ( & mut self . data , pos) ;
543- let mut child = 2 * pos + 1 ;
544- while child < end - 1 {
545- // compare with the greater of the two children
546- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
547- // if we are already in order, stop.
548- if hole. element ( ) >= hole. get ( child) {
549- return ;
550- }
551- hole. move_to ( child) ;
552- child = 2 * hole. pos ( ) + 1 ;
553- }
554- if child == end - 1 && hole. element ( ) < hole. get ( child) {
555- hole. move_to ( child) ;
559+ ///
560+ /// # Safety
561+ ///
562+ /// The caller must guarantee that `pos < end <= self.len()`.
563+ unsafe fn sift_down_range ( & mut self , pos : usize , end : usize ) {
564+ // SAFETY: The caller guarantees that pos < end <= self.len().
565+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
566+ let mut child = 2 * hole. pos ( ) + 1 ;
567+
568+ // Loop invariant: child == 2 * hole.pos() + 1.
569+ while child <= end. saturating_sub ( 2 ) {
570+ // compare with the greater of the two children
571+ // SAFETY: child < end - 1 < self.len() and
572+ // child + 1 < end <= self.len(), so they're valid indexes.
573+ // child == 2 * hole.pos() + 1 != hole.pos() and
574+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
575+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
576+ // if T is a ZST
577+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
578+
579+ // if we are already in order, stop.
580+ // SAFETY: child is now either the old child or the old child+1
581+ // We already proven that both are < self.len() and != hole.pos()
582+ if hole. element ( ) >= unsafe { hole. get ( child) } {
583+ return ;
556584 }
585+
586+ // SAFETY: same as above.
587+ unsafe { hole. move_to ( child) } ;
588+ child = 2 * hole. pos ( ) + 1 ;
589+ }
590+
591+ // SAFETY: && short circuit, which means that in the
592+ // second condition it's already true that child == end - 1 < self.len().
593+ if child == end - 1 && hole. element ( ) < unsafe { hole. get ( child) } {
594+ // SAFETY: child is already proven to be a valid index and
595+ // child == 2 * hole.pos() + 1 != hole.pos().
596+ unsafe { hole. move_to ( child) } ;
557597 }
558598 }
559599
560- fn sift_down ( & mut self , pos : usize ) {
600+ /// # Safety
601+ ///
602+ /// The caller must guarantee that `pos < self.len()`.
603+ unsafe fn sift_down ( & mut self , pos : usize ) {
561604 let len = self . len ( ) ;
562- self . sift_down_range ( pos, len) ;
605+ // SAFETY: pos < len is guaranteed by the caller and
606+ // obviously len = self.len() <= self.len().
607+ unsafe { self . sift_down_range ( pos, len) } ;
563608 }
564609
565610 /// Take an element at `pos` and move it all the way down the heap,
566611 /// then sift it up to its position.
567612 ///
568613 /// Note: This is faster when the element is known to be large / should
569614 /// be closer to the bottom.
570- fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
615+ ///
616+ /// # Safety
617+ ///
618+ /// The caller must guarantee that `pos < self.len()`.
619+ unsafe fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
571620 let end = self . len ( ) ;
572621 let start = pos;
573- unsafe {
574- let mut hole = Hole :: new ( & mut self . data , pos) ;
575- let mut child = 2 * pos + 1 ;
576- while child < end - 1 {
577- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
578- hole. move_to ( child) ;
579- child = 2 * hole. pos ( ) + 1 ;
580- }
581- if child == end - 1 {
582- hole. move_to ( child) ;
583- }
584- pos = hole. pos ;
622+
623+ // SAFETY: The caller guarantees that pos < self.len().
624+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
625+ let mut child = 2 * hole. pos ( ) + 1 ;
626+
627+ // Loop invariant: child == 2 * hole.pos() + 1.
628+ while child <= end. saturating_sub ( 2 ) {
629+ // SAFETY: child < end - 1 < self.len() and
630+ // child + 1 < end <= self.len(), so they're valid indexes.
631+ // child == 2 * hole.pos() + 1 != hole.pos() and
632+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
633+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
634+ // if T is a ZST
635+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
636+
637+ // SAFETY: Same as above
638+ unsafe { hole. move_to ( child) } ;
639+ child = 2 * hole. pos ( ) + 1 ;
585640 }
586- self . sift_up ( start, pos) ;
641+
642+ if child == end - 1 {
643+ // SAFETY: child == end - 1 < self.len(), so it's a valid index
644+ // and child == 2 * hole.pos() + 1 != hole.pos().
645+ unsafe { hole. move_to ( child) } ;
646+ }
647+ pos = hole. pos ( ) ;
648+ drop ( hole) ;
649+
650+ // SAFETY: pos is the position in the hole and was already proven
651+ // to be a valid index.
652+ unsafe { self . sift_up ( start, pos) } ;
587653 }
588654
589655 fn rebuild ( & mut self ) {
590656 let mut n = self . len ( ) / 2 ;
591657 while n > 0 {
592658 n -= 1 ;
593- self . sift_down ( n) ;
659+ // SAFETY: n starts from self.len() / 2 and goes down to 0.
660+ // The only case when !(n < self.len()) is if
661+ // self.len() == 0, but it's ruled out by the loop condition.
662+ unsafe { self . sift_down ( n) } ;
594663 }
595664 }
596665
0 commit comments