3939) ]
4040#![ cfg_attr( docsrs, feature( doc_auto_cfg) ) ]
4141
42+ use std:: collections:: HashMap ;
4243use std:: fmt;
4344use std:: marker:: PhantomData ;
4445use std:: panic:: { RefUnwindSafe , UnwindSafe } ;
4546use std:: rc:: Rc ;
46- use std:: sync:: atomic:: { AtomicBool , AtomicPtr , Ordering } ;
47+ use std:: sync:: atomic:: { AtomicBool , AtomicPtr , AtomicUsize , Ordering } ;
4748use std:: sync:: { Arc , Mutex , RwLock , TryLockError } ;
4849use std:: task:: { Poll , Waker } ;
50+ use std:: thread:: { self , ThreadId } ;
4951
5052use async_task:: { Builder , Runnable } ;
5153use concurrent_queue:: ConcurrentQueue ;
@@ -347,8 +349,32 @@ impl<'a> Executor<'a> {
347349 fn schedule ( & self ) -> impl Fn ( Runnable ) + Send + Sync + ' static {
348350 let state = self . state_as_arc ( ) ;
349351
350- // TODO: If possible, push into the current local queue and notify the ticker.
351- move |runnable| {
352+ move |mut runnable| {
353+ // If possible, push into the current local queue and notify the ticker.
354+ if let Some ( local_queue) = state
355+ . local_queues
356+ . read ( )
357+ . unwrap ( )
358+ . get ( & thread:: current ( ) . id ( ) )
359+ . and_then ( |list| list. first ( ) )
360+ {
361+ match local_queue. queue . push ( runnable) {
362+ Ok ( ( ) ) => {
363+ if let Some ( waker) = state
364+ . sleepers
365+ . lock ( )
366+ . unwrap ( )
367+ . notify_runner ( local_queue. runner_id )
368+ {
369+ waker. wake ( ) ;
370+ }
371+ return ;
372+ }
373+
374+ Err ( r) => runnable = r. into_inner ( ) ,
375+ }
376+ }
377+
352378 state. queue . push ( runnable) . unwrap ( ) ;
353379 state. notify ( ) ;
354380 }
@@ -665,7 +691,9 @@ struct State {
665691 queue : ConcurrentQueue < Runnable > ,
666692
667693 /// Local queues created by runners.
668- local_queues : RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ,
694+ ///
695+ /// These are keyed by the thread that the runner originated in.
696+ local_queues : RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ,
669697
670698 /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
671699 notified : AtomicBool ,
@@ -682,7 +710,7 @@ impl State {
682710 const fn new ( ) -> State {
683711 State {
684712 queue : ConcurrentQueue :: unbounded ( ) ,
685- local_queues : RwLock :: new ( Vec :: new ( ) ) ,
713+ local_queues : RwLock :: new ( HashMap :: new ( ) ) ,
686714 notified : AtomicBool :: new ( true ) ,
687715 sleepers : Mutex :: new ( Sleepers {
688716 count : 0 ,
@@ -756,36 +784,57 @@ struct Sleepers {
756784 /// IDs and wakers of sleeping unnotified tickers.
757785 ///
758786 /// A sleeping ticker is notified when its waker is missing from this list.
759- wakers : Vec < ( usize , Waker ) > ,
787+ wakers : Vec < Sleeper > ,
760788
761789 /// Reclaimed IDs.
762790 free_ids : Vec < usize > ,
763791}
764792
/// A single sleeping ticker.
struct Sleeper {
    /// ID of the sleeping ticker.
    id: usize,

    /// Waker associated with this ticker.
    waker: Waker,

    /// Specific runner ID for targeted wakeups.
    ///
    /// `None` for plain tickers. When `Some`, `Sleepers::notify_runner` can
    /// wake exactly this runner after a task lands on its local queue.
    runner: Option<usize>,
}
804+
765805impl Sleepers {
766806 /// Inserts a new sleeping ticker.
767- fn insert ( & mut self , waker : & Waker ) -> usize {
807+ fn insert ( & mut self , waker : & Waker , runner : Option < usize > ) -> usize {
768808 let id = match self . free_ids . pop ( ) {
769809 Some ( id) => id,
770810 None => self . count + 1 ,
771811 } ;
772812 self . count += 1 ;
773- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
813+ self . wakers . push ( Sleeper {
814+ id,
815+ waker : waker. clone ( ) ,
816+ runner,
817+ } ) ;
774818 id
775819 }
776820
777821 /// Re-inserts a sleeping ticker's waker if it was notified.
778822 ///
779823 /// Returns `true` if the ticker was notified.
780- fn update ( & mut self , id : usize , waker : & Waker ) -> bool {
824+ fn update ( & mut self , id : usize , waker : & Waker , runner : Option < usize > ) -> bool {
781825 for item in & mut self . wakers {
782- if item. 0 == id {
783- item. 1 . clone_from ( waker) ;
826+ if item. id == id {
827+ debug_assert_eq ! ( item. runner, runner) ;
828+ item. waker . clone_from ( waker) ;
784829 return false ;
785830 }
786831 }
787832
788- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
833+ self . wakers . push ( Sleeper {
834+ id,
835+ waker : waker. clone ( ) ,
836+ runner,
837+ } ) ;
789838 true
790839 }
791840
@@ -797,7 +846,7 @@ impl Sleepers {
797846 self . free_ids . push ( id) ;
798847
799848 for i in ( 0 ..self . wakers . len ( ) ) . rev ( ) {
800- if self . wakers [ i] . 0 == id {
849+ if self . wakers [ i] . id == id {
801850 self . wakers . remove ( i) ;
802851 return false ;
803852 }
@@ -815,7 +864,20 @@ impl Sleepers {
815864 /// If a ticker was notified already or there are no tickers, `None` will be returned.
816865 fn notify ( & mut self ) -> Option < Waker > {
817866 if self . wakers . len ( ) == self . count {
818- self . wakers . pop ( ) . map ( |item| item. 1 )
867+ self . wakers . pop ( ) . map ( |item| item. waker )
868+ } else {
869+ None
870+ }
871+ }
872+
873+ /// Notify a specific waker that was previously sleeping.
874+ fn notify_runner ( & mut self , runner : usize ) -> Option < Waker > {
875+ if let Some ( posn) = self
876+ . wakers
877+ . iter ( )
878+ . position ( |sleeper| sleeper. runner == Some ( runner) )
879+ {
880+ Some ( self . wakers . swap_remove ( posn) . waker )
819881 } else {
820882 None
821883 }
@@ -834,12 +896,28 @@ struct Ticker<'a> {
834896 /// 2a) Sleeping and unnotified.
835897 /// 2b) Sleeping and notified.
836898 sleeping : usize ,
899+
900+ /// Unique runner ID, if this is a runner.
901+ runner : Option < usize > ,
837902}
838903
839904impl Ticker < ' _ > {
840905 /// Creates a ticker.
841906 fn new ( state : & State ) -> Ticker < ' _ > {
842- Ticker { state, sleeping : 0 }
907+ Ticker {
908+ state,
909+ sleeping : 0 ,
910+ runner : None ,
911+ }
912+ }
913+
914+ /// Creates a ticker for a runner.
915+ fn for_runner ( state : & State , runner : usize ) -> Ticker < ' _ > {
916+ Ticker {
917+ state,
918+ sleeping : 0 ,
919+ runner : Some ( runner) ,
920+ }
843921 }
844922
845923 /// Moves the ticker into sleeping and unnotified state.
@@ -851,12 +929,12 @@ impl Ticker<'_> {
851929 match self . sleeping {
852930 // Move to sleeping state.
853931 0 => {
854- self . sleeping = sleepers. insert ( waker) ;
932+ self . sleeping = sleepers. insert ( waker, self . runner ) ;
855933 }
856934
857935 // Already sleeping, check if notified.
858936 id => {
859- if !sleepers. update ( id, waker) {
937+ if !sleepers. update ( id, waker, self . runner ) {
860938 return false ;
861939 }
862940 }
@@ -946,8 +1024,11 @@ struct Runner<'a> {
9461024 /// Inner ticker.
9471025 ticker : Ticker < ' a > ,
9481026
1027+ /// The ID of the thread we originated from.
1028+ origin_id : ThreadId ,
1029+
9491030 /// The local queue.
950- local : Arc < ConcurrentQueue < Runnable > > ,
1031+ local : Arc < LocalQueue > ,
9511032
9521033 /// Bumped every time a runnable task is found.
9531034 ticks : usize ,
@@ -956,16 +1037,26 @@ struct Runner<'a> {
9561037impl Runner < ' _ > {
9571038 /// Creates a runner and registers it in the executor state.
9581039 fn new ( state : & State ) -> Runner < ' _ > {
1040+ static ID_GENERATOR : AtomicUsize = AtomicUsize :: new ( 0 ) ;
1041+ let runner_id = ID_GENERATOR . fetch_add ( 1 , Ordering :: SeqCst ) ;
1042+
1043+ let origin_id = thread:: current ( ) . id ( ) ;
9591044 let runner = Runner {
9601045 state,
961- ticker : Ticker :: new ( state) ,
962- local : Arc :: new ( ConcurrentQueue :: bounded ( 512 ) ) ,
1046+ ticker : Ticker :: for_runner ( state, runner_id) ,
1047+ local : Arc :: new ( LocalQueue {
1048+ queue : ConcurrentQueue :: bounded ( 512 ) ,
1049+ runner_id,
1050+ } ) ,
9631051 ticks : 0 ,
1052+ origin_id,
9641053 } ;
9651054 state
9661055 . local_queues
9671056 . write ( )
9681057 . unwrap ( )
1058+ . entry ( origin_id)
1059+ . or_default ( )
9691060 . push ( runner. local . clone ( ) ) ;
9701061 runner
9711062 }
@@ -976,13 +1067,13 @@ impl Runner<'_> {
9761067 . ticker
9771068 . runnable_with ( || {
9781069 // Try the local queue.
979- if let Ok ( r) = self . local . pop ( ) {
1070+ if let Ok ( r) = self . local . queue . pop ( ) {
9801071 return Some ( r) ;
9811072 }
9821073
9831074 // Try stealing from the global queue.
9841075 if let Ok ( r) = self . state . queue . pop ( ) {
985- steal ( & self . state . queue , & self . local ) ;
1076+ steal ( & self . state . queue , & self . local . queue ) ;
9861077 return Some ( r) ;
9871078 }
9881079
@@ -994,7 +1085,8 @@ impl Runner<'_> {
9941085 let start = rng. usize ( ..n) ;
9951086 let iter = local_queues
9961087 . iter ( )
997- . chain ( local_queues. iter ( ) )
1088+ . flat_map ( |( _, list) | list)
1089+ . chain ( local_queues. iter ( ) . flat_map ( |( _, list) | list) )
9981090 . skip ( start)
9991091 . take ( n) ;
10001092
@@ -1003,8 +1095,8 @@ impl Runner<'_> {
10031095
10041096 // Try stealing from each local queue in the list.
10051097 for local in iter {
1006- steal ( local, & self . local ) ;
1007- if let Ok ( r) = self . local . pop ( ) {
1098+ steal ( & local. queue , & self . local . queue ) ;
1099+ if let Ok ( r) = self . local . queue . pop ( ) {
10081100 return Some ( r) ;
10091101 }
10101102 }
@@ -1018,7 +1110,7 @@ impl Runner<'_> {
10181110
10191111 if self . ticks % 64 == 0 {
10201112 // Steal tasks from the global queue to ensure fair task scheduling.
1021- steal ( & self . state . queue , & self . local ) ;
1113+ steal ( & self . state . queue , & self . local . queue ) ;
10221114 }
10231115
10241116 runnable
@@ -1032,15 +1124,26 @@ impl Drop for Runner<'_> {
10321124 . local_queues
10331125 . write ( )
10341126 . unwrap ( )
1127+ . get_mut ( & self . origin_id )
1128+ . unwrap ( )
10351129 . retain ( |local| !Arc :: ptr_eq ( local, & self . local ) ) ;
10361130
10371131 // Re-schedule remaining tasks in the local queue.
1038- while let Ok ( r) = self . local . pop ( ) {
1132+ while let Ok ( r) = self . local . queue . pop ( ) {
10391133 r. schedule ( ) ;
10401134 }
10411135 }
10421136}
10431137
/// Data associated with a local queue.
struct LocalQueue {
    /// Concurrent queue of active tasks.
    queue: ConcurrentQueue<Runnable>,

    /// Unique ID associated with this runner.
    ///
    /// Lets a scheduler that pushed onto `queue` wake the owning runner
    /// via `Sleepers::notify_runner`.
    runner_id: usize,
}
1146+
10441147/// Steals some items from one queue into another.
10451148fn steal < T > ( src : & ConcurrentQueue < T > , dest : & ConcurrentQueue < T > ) {
10461149 // Half of `src`'s length rounded up.
@@ -1104,14 +1207,18 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
11041207 }
11051208
11061209 /// Debug wrapper for the local runners.
1107- struct LocalRunners < ' a > ( & ' a RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ) ;
1210+ struct LocalRunners < ' a > ( & ' a RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ) ;
11081211
11091212 impl fmt:: Debug for LocalRunners < ' _ > {
11101213 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
11111214 match self . 0 . try_read ( ) {
11121215 Ok ( lock) => f
11131216 . debug_list ( )
1114- . entries ( lock. iter ( ) . map ( |queue| queue. len ( ) ) )
1217+ . entries (
1218+ lock. iter ( )
1219+ . flat_map ( |( _, list) | list)
1220+ . map ( |queue| queue. queue . len ( ) ) ,
1221+ )
11151222 . finish ( ) ,
11161223 Err ( TryLockError :: WouldBlock ) => f. write_str ( "<locked>" ) ,
11171224 Err ( TryLockError :: Poisoned ( _) ) => f. write_str ( "<poisoned>" ) ,
0 commit comments