3838 html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
3939) ]
4040
41+ use std:: collections:: HashMap ;
4142use std:: fmt;
4243use std:: marker:: PhantomData ;
4344use std:: panic:: { RefUnwindSafe , UnwindSafe } ;
4445use std:: rc:: Rc ;
45- use std:: sync:: atomic:: { AtomicBool , AtomicPtr , Ordering } ;
46+ use std:: sync:: atomic:: { AtomicBool , AtomicPtr , AtomicUsize , Ordering } ;
4647use std:: sync:: { Arc , Mutex , RwLock , TryLockError } ;
4748use std:: task:: { Poll , Waker } ;
49+ use std:: thread:: { self , ThreadId } ;
4850
4951use async_task:: { Builder , Runnable } ;
5052use concurrent_queue:: ConcurrentQueue ;
@@ -369,8 +371,32 @@ impl<'a> Executor<'a> {
369371 fn schedule ( & self ) -> impl Fn ( Runnable ) + Send + Sync + ' static {
370372 let state = self . state_as_arc ( ) ;
371373
372- // TODO: If possible, push into the current local queue and notify the ticker.
373- move |runnable| {
374+ move |mut runnable| {
375+ // If possible, push into the current local queue and notify the ticker.
376+ if let Some ( local_queue) = state
377+ . local_queues
378+ . read ( )
379+ . unwrap ( )
380+ . get ( & thread:: current ( ) . id ( ) )
381+ . and_then ( |list| list. first ( ) )
382+ {
383+ match local_queue. queue . push ( runnable) {
384+ Ok ( ( ) ) => {
385+ if let Some ( waker) = state
386+ . sleepers
387+ . lock ( )
388+ . unwrap ( )
389+ . notify_runner ( local_queue. runner_id )
390+ {
391+ waker. wake ( ) ;
392+ }
393+ return ;
394+ }
395+
396+ Err ( r) => runnable = r. into_inner ( ) ,
397+ }
398+ }
399+
374400 state. queue . push ( runnable) . unwrap ( ) ;
375401 state. notify ( ) ;
376402 }
@@ -687,7 +713,9 @@ struct State {
687713 queue : ConcurrentQueue < Runnable > ,
688714
689715 /// Local queues created by runners.
690- local_queues : RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ,
716+ ///
717+ /// These are keyed by the thread that the runner originated in.
718+ local_queues : RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ,
691719
692720 /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
693721 notified : AtomicBool ,
@@ -704,7 +732,7 @@ impl State {
704732 fn new ( ) -> State {
705733 State {
706734 queue : ConcurrentQueue :: unbounded ( ) ,
707- local_queues : RwLock :: new ( Vec :: new ( ) ) ,
735+ local_queues : RwLock :: new ( HashMap :: new ( ) ) ,
708736 notified : AtomicBool :: new ( true ) ,
709737 sleepers : Mutex :: new ( Sleepers {
710738 count : 0 ,
@@ -739,36 +767,57 @@ struct Sleepers {
739767 /// IDs and wakers of sleeping unnotified tickers.
740768 ///
741769 /// A sleeping ticker is notified when its waker is missing from this list.
742- wakers : Vec < ( usize , Waker ) > ,
770+ wakers : Vec < Sleeper > ,
743771
744772 /// Reclaimed IDs.
745773 free_ids : Vec < usize > ,
746774}
747775
776+ /// A single sleeping ticker.
777+ struct Sleeper {
778+ /// ID of the sleeping ticker.
779+ id : usize ,
780+
781+ /// Waker associated with this ticker.
782+ waker : Waker ,
783+
784+ /// Specific runner ID for targeted wakeups.
785+ runner : Option < usize > ,
786+ }
787+
748788impl Sleepers {
749789 /// Inserts a new sleeping ticker.
750- fn insert ( & mut self , waker : & Waker ) -> usize {
790+ fn insert ( & mut self , waker : & Waker , runner : Option < usize > ) -> usize {
751791 let id = match self . free_ids . pop ( ) {
752792 Some ( id) => id,
753793 None => self . count + 1 ,
754794 } ;
755795 self . count += 1 ;
756- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
796+ self . wakers . push ( Sleeper {
797+ id,
798+ waker : waker. clone ( ) ,
799+ runner,
800+ } ) ;
757801 id
758802 }
759803
760804 /// Re-inserts a sleeping ticker's waker if it was notified.
761805 ///
762806 /// Returns `true` if the ticker was notified.
763- fn update ( & mut self , id : usize , waker : & Waker ) -> bool {
807+ fn update ( & mut self , id : usize , waker : & Waker , runner : Option < usize > ) -> bool {
764808 for item in & mut self . wakers {
765- if item. 0 == id {
766- item. 1 . clone_from ( waker) ;
809+ if item. id == id {
810+ debug_assert_eq ! ( item. runner, runner) ;
811+ item. waker . clone_from ( waker) ;
767812 return false ;
768813 }
769814 }
770815
771- self . wakers . push ( ( id, waker. clone ( ) ) ) ;
816+ self . wakers . push ( Sleeper {
817+ id,
818+ waker : waker. clone ( ) ,
819+ runner,
820+ } ) ;
772821 true
773822 }
774823
@@ -780,7 +829,7 @@ impl Sleepers {
780829 self . free_ids . push ( id) ;
781830
782831 for i in ( 0 ..self . wakers . len ( ) ) . rev ( ) {
783- if self . wakers [ i] . 0 == id {
832+ if self . wakers [ i] . id == id {
784833 self . wakers . remove ( i) ;
785834 return false ;
786835 }
@@ -798,7 +847,20 @@ impl Sleepers {
798847 /// If a ticker was notified already or there are no tickers, `None` will be returned.
799848 fn notify ( & mut self ) -> Option < Waker > {
800849 if self . wakers . len ( ) == self . count {
801- self . wakers . pop ( ) . map ( |item| item. 1 )
850+ self . wakers . pop ( ) . map ( |item| item. waker )
851+ } else {
852+ None
853+ }
854+ }
855+
856+ /// Notify a specific waker that was previously sleeping.
857+ fn notify_runner ( & mut self , runner : usize ) -> Option < Waker > {
858+ if let Some ( posn) = self
859+ . wakers
860+ . iter ( )
861+ . position ( |sleeper| sleeper. runner == Some ( runner) )
862+ {
863+ Some ( self . wakers . swap_remove ( posn) . waker )
802864 } else {
803865 None
804866 }
@@ -817,12 +879,28 @@ struct Ticker<'a> {
817879 /// 2a) Sleeping and unnotified.
818880 /// 2b) Sleeping and notified.
819881 sleeping : usize ,
882+
883+ /// Unique runner ID, if this is a runner.
884+ runner : Option < usize > ,
820885}
821886
822887impl Ticker < ' _ > {
823888 /// Creates a ticker.
824889 fn new ( state : & State ) -> Ticker < ' _ > {
825- Ticker { state, sleeping : 0 }
890+ Ticker {
891+ state,
892+ sleeping : 0 ,
893+ runner : None ,
894+ }
895+ }
896+
897+ /// Creates a ticker for a runner.
898+ fn for_runner ( state : & State , runner : usize ) -> Ticker < ' _ > {
899+ Ticker {
900+ state,
901+ sleeping : 0 ,
902+ runner : Some ( runner) ,
903+ }
826904 }
827905
828906 /// Moves the ticker into sleeping and unnotified state.
@@ -834,12 +912,12 @@ impl Ticker<'_> {
834912 match self . sleeping {
835913 // Move to sleeping state.
836914 0 => {
837- self . sleeping = sleepers. insert ( waker) ;
915+ self . sleeping = sleepers. insert ( waker, self . runner ) ;
838916 }
839917
840918 // Already sleeping, check if notified.
841919 id => {
842- if !sleepers. update ( id, waker) {
920+ if !sleepers. update ( id, waker, self . runner ) {
843921 return false ;
844922 }
845923 }
@@ -929,8 +1007,11 @@ struct Runner<'a> {
9291007 /// Inner ticker.
9301008 ticker : Ticker < ' a > ,
9311009
1010+ /// The ID of the thread we originated from.
1011+ origin_id : ThreadId ,
1012+
9321013 /// The local queue.
933- local : Arc < ConcurrentQueue < Runnable > > ,
1014+ local : Arc < LocalQueue > ,
9341015
9351016 /// Bumped every time a runnable task is found.
9361017 ticks : usize ,
@@ -939,16 +1020,26 @@ struct Runner<'a> {
9391020impl Runner < ' _ > {
9401021 /// Creates a runner and registers it in the executor state.
9411022 fn new ( state : & State ) -> Runner < ' _ > {
1023+ static ID_GENERATOR : AtomicUsize = AtomicUsize :: new ( 0 ) ;
1024+ let runner_id = ID_GENERATOR . fetch_add ( 1 , Ordering :: SeqCst ) ;
1025+
1026+ let origin_id = thread:: current ( ) . id ( ) ;
9421027 let runner = Runner {
9431028 state,
944- ticker : Ticker :: new ( state) ,
945- local : Arc :: new ( ConcurrentQueue :: bounded ( 512 ) ) ,
1029+ ticker : Ticker :: for_runner ( state, runner_id) ,
1030+ local : Arc :: new ( LocalQueue {
1031+ queue : ConcurrentQueue :: bounded ( 512 ) ,
1032+ runner_id,
1033+ } ) ,
9461034 ticks : 0 ,
1035+ origin_id,
9471036 } ;
9481037 state
9491038 . local_queues
9501039 . write ( )
9511040 . unwrap ( )
1041+ . entry ( origin_id)
1042+ . or_default ( )
9521043 . push ( runner. local . clone ( ) ) ;
9531044 runner
9541045 }
@@ -959,13 +1050,13 @@ impl Runner<'_> {
9591050 . ticker
9601051 . runnable_with ( || {
9611052 // Try the local queue.
962- if let Ok ( r) = self . local . pop ( ) {
1053+ if let Ok ( r) = self . local . queue . pop ( ) {
9631054 return Some ( r) ;
9641055 }
9651056
9661057 // Try stealing from the global queue.
9671058 if let Ok ( r) = self . state . queue . pop ( ) {
968- steal ( & self . state . queue , & self . local ) ;
1059+ steal ( & self . state . queue , & self . local . queue ) ;
9691060 return Some ( r) ;
9701061 }
9711062
@@ -977,7 +1068,8 @@ impl Runner<'_> {
9771068 let start = rng. usize ( ..n) ;
9781069 let iter = local_queues
9791070 . iter ( )
980- . chain ( local_queues. iter ( ) )
1071+ . flat_map ( |( _, list) | list)
1072+ . chain ( local_queues. iter ( ) . flat_map ( |( _, list) | list) )
9811073 . skip ( start)
9821074 . take ( n) ;
9831075
@@ -986,8 +1078,8 @@ impl Runner<'_> {
9861078
9871079 // Try stealing from each local queue in the list.
9881080 for local in iter {
989- steal ( local, & self . local ) ;
990- if let Ok ( r) = self . local . pop ( ) {
1081+ steal ( & local. queue , & self . local . queue ) ;
1082+ if let Ok ( r) = self . local . queue . pop ( ) {
9911083 return Some ( r) ;
9921084 }
9931085 }
@@ -1001,7 +1093,7 @@ impl Runner<'_> {
10011093
10021094 if self . ticks % 64 == 0 {
10031095 // Steal tasks from the global queue to ensure fair task scheduling.
1004- steal ( & self . state . queue , & self . local ) ;
1096+ steal ( & self . state . queue , & self . local . queue ) ;
10051097 }
10061098
10071099 runnable
@@ -1015,15 +1107,26 @@ impl Drop for Runner<'_> {
10151107 . local_queues
10161108 . write ( )
10171109 . unwrap ( )
1110+ . get_mut ( & self . origin_id )
1111+ . unwrap ( )
10181112 . retain ( |local| !Arc :: ptr_eq ( local, & self . local ) ) ;
10191113
10201114 // Re-schedule remaining tasks in the local queue.
1021- while let Ok ( r) = self . local . pop ( ) {
1115+ while let Ok ( r) = self . local . queue . pop ( ) {
10221116 r. schedule ( ) ;
10231117 }
10241118 }
10251119}
10261120
1121+ /// Data associated with a local queue.
1122+ struct LocalQueue {
1123+ /// Concurrent queue of active tasks.
1124+ queue : ConcurrentQueue < Runnable > ,
1125+
1126+ /// Unique ID associated with this runner.
1127+ runner_id : usize ,
1128+ }
1129+
10271130/// Steals some items from one queue into another.
10281131fn steal < T > ( src : & ConcurrentQueue < T > , dest : & ConcurrentQueue < T > ) {
10291132 // Half of `src`'s length rounded up.
@@ -1082,14 +1185,18 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
10821185 }
10831186
10841187 /// Debug wrapper for the local runners.
1085- struct LocalRunners < ' a > ( & ' a RwLock < Vec < Arc < ConcurrentQueue < Runnable > > > > ) ;
1188+ struct LocalRunners < ' a > ( & ' a RwLock < HashMap < ThreadId , Vec < Arc < LocalQueue > > > > ) ;
10861189
10871190 impl fmt:: Debug for LocalRunners < ' _ > {
10881191 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
10891192 match self . 0 . try_read ( ) {
10901193 Ok ( lock) => f
10911194 . debug_list ( )
1092- . entries ( lock. iter ( ) . map ( |queue| queue. len ( ) ) )
1195+ . entries (
1196+ lock. iter ( )
1197+ . flat_map ( |( _, list) | list)
1198+ . map ( |queue| queue. queue . len ( ) ) ,
1199+ )
10931200 . finish ( ) ,
10941201 Err ( TryLockError :: WouldBlock ) => f. write_str ( "<locked>" ) ,
10951202 Err ( TryLockError :: Poisoned ( _) ) => f. write_str ( "<poisoned>" ) ,
0 commit comments