@@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275275impl < T : Ord > Drop for PeekMut < ' _ , T > {
276276 fn drop ( & mut self ) {
277277 if self . sift {
278- self . heap . sift_down ( 0 ) ;
278+ // SAFETY: PeekMut is only instantiated for non-empty heaps.
279+ unsafe { self . heap . sift_down ( 0 ) } ;
279280 }
280281 }
281282}
@@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
431432 self . data . pop ( ) . map ( |mut item| {
432433 if !self . is_empty ( ) {
433434 swap ( & mut item, & mut self . data [ 0 ] ) ;
434- self . sift_down_to_bottom ( 0 ) ;
435+ // SAFETY: !self.is_empty() means that self.len() > 0
436+ unsafe { self . sift_down_to_bottom ( 0 ) } ;
435437 }
436438 item
437439 } )
@@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
473475 pub fn push ( & mut self , item : T ) {
474476 let old_len = self . len ( ) ;
475477 self . data . push ( item) ;
476- self . sift_up ( 0 , old_len) ;
478+ // SAFETY: Since we pushed a new item it means that
479+ // old_len = self.len() - 1 < self.len()
480+ unsafe { self . sift_up ( 0 , old_len) } ;
477481 }
478482
479483 /// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
506510 let ptr = self . data . as_mut_ptr ( ) ;
507511 ptr:: swap ( ptr, ptr. add ( end) ) ;
508512 }
509- self . sift_down_range ( 0 , end) ;
513+ // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
514+ // 0 < 1 <= end <= self.len() - 1 < self.len()
515+ // Which means 0 < end and end < self.len().
516+ unsafe { self . sift_down_range ( 0 , end) } ;
510517 }
511518 self . into_vec ( )
512519 }
@@ -519,78 +526,139 @@ impl<T: Ord> BinaryHeap<T> {
519526 // the hole is filled back at the end of its scope, even on panic.
520527 // Using a hole reduces the constant factor compared to using swaps,
521528 // which involves twice as many moves.
522- fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
523- unsafe {
524- // Take out the value at `pos` and create a hole.
525- let mut hole = Hole :: new ( & mut self . data , pos) ;
526-
527- while hole. pos ( ) > start {
528- let parent = ( hole. pos ( ) - 1 ) / 2 ;
529- if hole. element ( ) <= hole. get ( parent) {
530- break ;
531- }
532- hole. move_to ( parent) ;
529+
530+ /// # Safety
531+ ///
532+ /// The caller must guarantee that `pos < self.len()`.
533+ unsafe fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
534+ // Take out the value at `pos` and create a hole.
535+ // SAFETY: The caller guarantees that pos < self.len()
536+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
537+
538+ while hole. pos ( ) > start {
539+ let parent = ( hole. pos ( ) - 1 ) / 2 ;
540+
541+ // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
542+ // and so hole.pos() - 1 can't underflow.
543+ // This guarantees that parent < hole.pos() so
544+ // it's a valid index and also != hole.pos().
545+ if hole. element ( ) <= unsafe { hole. get ( parent) } {
546+ break ;
533547 }
534- hole. pos ( )
548+
549+ // SAFETY: Same as above
550+ unsafe { hole. move_to ( parent) } ;
535551 }
552+
553+ hole. pos ( )
536554 }
537555
538556 /// Take an element at `pos` and move it down the heap,
539557 /// while its children are larger.
540- fn sift_down_range ( & mut self , pos : usize , end : usize ) {
541- unsafe {
542- let mut hole = Hole :: new ( & mut self . data , pos) ;
543- let mut child = 2 * pos + 1 ;
544- while child < end - 1 {
545- // compare with the greater of the two children
546- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
547- // if we are already in order, stop.
548- if hole. element ( ) >= hole. get ( child) {
549- return ;
550- }
551- hole. move_to ( child) ;
552- child = 2 * hole. pos ( ) + 1 ;
553- }
554- if child == end - 1 && hole. element ( ) < hole. get ( child) {
555- hole. move_to ( child) ;
558+ ///
559+ /// # Safety
560+ ///
561+ /// The caller must guarantee that `pos < end <= self.len()`.
562+ unsafe fn sift_down_range ( & mut self , pos : usize , end : usize ) {
563+ // SAFETY: The caller guarantees that pos < end <= self.len().
564+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
565+ let mut child = 2 * hole. pos ( ) + 1 ;
566+
567+ // Loop invariant: child == 2 * hole.pos() + 1.
568+ while child < end - 1 {
569+ // compare with the greater of the two children
570+ // SAFETY: child < end - 1 < self.len() and
571+ // child + 1 < end <= self.len(), so they're valid indexes.
572+ // child == 2 * hole.pos() + 1 != hole.pos() and
573+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
574+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
575+ // if T is a ZST
576+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
577+
578+ // if we are already in order, stop.
579+ // SAFETY: child is now either the old child or the old child+1
580+ // We already proven that both are < self.len() and != hole.pos()
581+ if hole. element ( ) >= unsafe { hole. get ( child) } {
582+ return ;
556583 }
584+
585+ // SAFETY: same as above.
586+ unsafe { hole. move_to ( child) } ;
587+ child = 2 * hole. pos ( ) + 1 ;
588+ }
589+
590+ // SAFETY: && short circuit, which means that in the
591+ // second condition it's already true that child == end - 1 < self.len().
592+ if child == end - 1 && hole. element ( ) < unsafe { hole. get ( child) } {
593+ // SAFETY: child is already proven to be a valid index and
594+ // child == 2 * hole.pos() + 1 != hole.pos().
595+ unsafe { hole. move_to ( child) } ;
557596 }
558597 }
559598
560- fn sift_down ( & mut self , pos : usize ) {
599+ /// # Safety
600+ ///
601+ /// The caller must guarantee that `pos < self.len()`.
602+ unsafe fn sift_down ( & mut self , pos : usize ) {
561603 let len = self . len ( ) ;
562- self . sift_down_range ( pos, len) ;
604+ // SAFETY: pos < len is guaranteed by the caller and
605+ // obviously len = self.len() <= self.len().
606+ unsafe { self . sift_down_range ( pos, len) } ;
563607 }
564608
565609 /// Take an element at `pos` and move it all the way down the heap,
566610 /// then sift it up to its position.
567611 ///
568612 /// Note: This is faster when the element is known to be large / should
569613 /// be closer to the bottom.
570- fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
614+ ///
615+ /// # Safety
616+ ///
617+ /// The caller must guarantee that `pos < self.len()`.
618+ unsafe fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
571619 let end = self . len ( ) ;
572620 let start = pos;
573- unsafe {
574- let mut hole = Hole :: new ( & mut self . data , pos) ;
575- let mut child = 2 * pos + 1 ;
576- while child < end - 1 {
577- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
578- hole. move_to ( child) ;
579- child = 2 * hole. pos ( ) + 1 ;
580- }
581- if child == end - 1 {
582- hole. move_to ( child) ;
583- }
584- pos = hole. pos ;
621+
622+ // SAFETY: The caller guarantees that pos < self.len().
623+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
624+ let mut child = 2 * hole. pos ( ) + 1 ;
625+
626+ // Loop invariant: child == 2 * hole.pos() + 1.
627+ while child < end - 1 {
628+ // SAFETY: child < end - 1 < self.len() and
629+ // child + 1 < end <= self.len(), so they're valid indexes.
630+ // child == 2 * hole.pos() + 1 != hole.pos() and
631+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
632+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
633+ // if T is a ZST
634+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
635+
636+ // SAFETY: Same as above
637+ unsafe { hole. move_to ( child) } ;
638+ child = 2 * hole. pos ( ) + 1 ;
585639 }
586- self . sift_up ( start, pos) ;
640+
641+ if child == end - 1 {
642+ // SAFETY: child == end - 1 < self.len(), so it's a valid index
643+ // and child == 2 * hole.pos() + 1 != hole.pos().
644+ unsafe { hole. move_to ( child) } ;
645+ }
646+ pos = hole. pos ( ) ;
647+ drop ( hole) ;
648+
649+ // SAFETY: pos is the position in the hole and was already proven
650+ // to be a valid index.
651+ unsafe { self . sift_up ( start, pos) } ;
587652 }
588653
589654 fn rebuild ( & mut self ) {
590655 let mut n = self . len ( ) / 2 ;
591656 while n > 0 {
592657 n -= 1 ;
593- self . sift_down ( n) ;
658+ // SAFETY: n starts from self.len() / 2 and goes down to 0.
659+ // The only case when !(n < self.len()) is if
660+ // self.len() == 0, but it's ruled out by the loop condition.
661+ unsafe { self . sift_down ( n) } ;
594662 }
595663 }
596664
0 commit comments