@@ -15,7 +15,7 @@ use super::utils::{Backoff, CachePadded};
1515use super :: waker:: SyncWaker ;
1616
1717use crate :: cell:: UnsafeCell ;
18- use crate :: mem:: MaybeUninit ;
18+ use crate :: mem:: { self , MaybeUninit } ;
1919use crate :: ptr;
2020use crate :: sync:: atomic:: { self , AtomicUsize , Ordering } ;
2121use crate :: time:: Instant ;
@@ -25,7 +25,8 @@ struct Slot<T> {
2525 /// The current stamp.
2626 stamp : AtomicUsize ,
2727
28- /// The message in this slot.
28+ /// The message in this slot. Either read out in `read` or dropped through
29+ /// `discard_all_messages`.
2930 msg : UnsafeCell < MaybeUninit < T > > ,
3031}
3132
@@ -439,21 +440,122 @@ impl<T> Channel<T> {
439440 Some ( self . cap )
440441 }
441442
442- /// Disconnects the channel and wakes up all blocked senders and receivers.
443+ /// Disconnects senders and wakes up all blocked receivers.
443444 ///
444445 /// Returns `true` if this call disconnected the channel.
445- pub ( crate ) fn disconnect ( & self ) -> bool {
446+ pub ( crate ) fn disconnect_senders ( & self ) -> bool {
446447 let tail = self . tail . fetch_or ( self . mark_bit , Ordering :: SeqCst ) ;
447448
448449 if tail & self . mark_bit == 0 {
449- self . senders . disconnect ( ) ;
450450 self . receivers . disconnect ( ) ;
451451 true
452452 } else {
453453 false
454454 }
455455 }
456456
457+ /// Disconnects receivers and wakes up all blocked senders.
458+ ///
459+ /// Returns `true` if this call disconnected the channel.
460+ ///
461+ /// # Safety
462+ /// May only be called once upon dropping the last receiver. The
463+ /// destruction of all other receivers must have been observed with acquire
464+ /// ordering or stronger.
465+ pub ( crate ) unsafe fn disconnect_receivers ( & self ) -> bool {
466+ let tail = self . tail . fetch_or ( self . mark_bit , Ordering :: SeqCst ) ;
467+ self . discard_all_messages ( tail) ;
468+
469+ if tail & self . mark_bit == 0 {
470+ self . senders . disconnect ( ) ;
471+ true
472+ } else {
473+ false
474+ }
475+ }
476+
477+ /// Discards all messages.
478+ ///
479+ /// `tail` should be the current (and therefore last) value of `tail`.
480+ ///
481+ /// # Safety
482+ /// This method must only be called when dropping the last receiver. The
483+ /// destruction of all other receivers must have been observed with acquire
484+ /// ordering or stronger.
485+ unsafe fn discard_all_messages ( & self , tail : usize ) {
486+ debug_assert ! ( self . is_disconnected( ) ) ;
487+
488+ /// Use a helper struct with a custom `Drop` to ensure all messages are
489+ /// dropped, even if a destructor panicks.
490+ struct DiscardState < ' a , T > {
491+ channel : & ' a Channel < T > ,
492+ head : usize ,
493+ tail : usize ,
494+ backoff : Backoff ,
495+ }
496+
497+ impl < ' a , T > DiscardState < ' a , T > {
498+ fn discard ( & mut self ) {
499+ loop {
500+ // Deconstruct the head.
501+ let index = self . head & ( self . channel . mark_bit - 1 ) ;
502+ let lap = self . head & !( self . channel . one_lap - 1 ) ;
503+
504+ // Inspect the corresponding slot.
505+ debug_assert ! ( index < self . channel. buffer. len( ) ) ;
506+ let slot = unsafe { self . channel . buffer . get_unchecked ( index) } ;
507+ let stamp = slot. stamp . load ( Ordering :: Acquire ) ;
508+
509+ // If the stamp is ahead of the head by 1, we may drop the message.
510+ if self . head + 1 == stamp {
511+ self . head = if index + 1 < self . channel . cap {
512+ // Same lap, incremented index.
513+ // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
514+ self . head + 1
515+ } else {
516+ // One lap forward, index wraps around to zero.
517+ // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
518+ lap. wrapping_add ( self . channel . one_lap )
519+ } ;
520+
521+ // We updated the head, so even if this descrutor panics,
522+ // we will not attempt to destroy the slot again.
523+ unsafe {
524+ ( * slot. msg . get ( ) ) . assume_init_drop ( ) ;
525+ }
526+ // If the tail equals the head, that means the channel is empty.
527+ } else if self . tail == self . head {
528+ return ;
529+ // Otherwise, a sender is about to write into the slot, so we need
530+ // to wait for it to update the stamp.
531+ } else {
532+ self . backoff . spin_heavy ( ) ;
533+ }
534+ }
535+ }
536+ }
537+
538+ impl < ' a , T > Drop for DiscardState < ' a , T > {
539+ fn drop ( & mut self ) {
540+ self . discard ( ) ;
541+ }
542+ }
543+
544+ let mut state = DiscardState {
545+ channel : self ,
546+ // Only receivers modify `head`, so since we are the last one,
547+ // this value will not change and will not be observed (since
548+ // no new messages can be sent after disconnection).
549+ head : self . head . load ( Ordering :: Relaxed ) ,
550+ tail : tail & !self . mark_bit ,
551+ backoff : Backoff :: new ( ) ,
552+ } ;
553+ state. discard ( ) ;
554+ // This point is only reached if no destructor panics, so all messages
555+ // have already been dropped.
556+ mem:: forget ( state) ;
557+ }
558+
457559 /// Returns `true` if the channel is disconnected.
458560 pub ( crate ) fn is_disconnected ( & self ) -> bool {
459561 self . tail . load ( Ordering :: SeqCst ) & self . mark_bit != 0
@@ -483,23 +585,3 @@ impl<T> Channel<T> {
483585 head. wrapping_add ( self . one_lap ) == tail & !self . mark_bit
484586 }
485587}
486-
487- impl < T > Drop for Channel < T > {
488- fn drop ( & mut self ) {
489- // Get the index of the head.
490- let hix = self . head . load ( Ordering :: Relaxed ) & ( self . mark_bit - 1 ) ;
491-
492- // Loop over all slots that hold a message and drop them.
493- for i in 0 ..self . len ( ) {
494- // Compute the index of the next slot holding a message.
495- let index = if hix + i < self . cap { hix + i } else { hix + i - self . cap } ;
496-
497- unsafe {
498- debug_assert ! ( index < self . buffer. len( ) ) ;
499- let slot = self . buffer . get_unchecked_mut ( index) ;
500- let msg = & mut * slot. msg . get ( ) ;
501- msg. as_mut_ptr ( ) . drop_in_place ( ) ;
502- }
503- }
504- }
505- }
0 commit comments