@@ -9,7 +9,7 @@ use std::pin::Pin;
99use std:: ptr;
1010use std:: sync:: atomic:: AtomicUsize ;
1111use std:: sync:: atomic:: Ordering :: { Acquire , SeqCst } ;
12- use std:: sync:: { Arc , Mutex , Weak } ;
12+ use std:: sync:: { Arc , Mutex , MutexGuard , Weak } ;
1313
1414/// Future for the [`shared`](super::FutureExt::shared) method.
1515#[ must_use = "futures do nothing unless you `.await` or poll them" ]
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
8181const POLLING : usize = 1 ;
8282const COMPLETE : usize = 2 ;
8383const POISONED : usize = 3 ;
84+ const WOKEN_DURING_POLLING : usize = 4 ;
8485
8586const NULL_WAKER_KEY : usize = usize:: MAX ;
8687
@@ -197,36 +198,47 @@ where
197198 }
198199}
199200
200- impl < Fut > Inner < Fut >
201- where
202- Fut : Future ,
203- Fut :: Output : Clone ,
204- {
205- /// Registers the current task to receive a wakeup when we are awoken.
206- fn record_waker ( & self , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
207- let mut wakers_guard = self . notifier . wakers . lock ( ) . unwrap ( ) ;
208-
209- let wakers_mut = wakers_guard. as_mut ( ) ;
210-
211- let wakers = match wakers_mut {
212- Some ( wakers) => wakers,
213- None => return ,
214- } ;
215-
216- let new_waker = cx. waker ( ) ;
201+ /// Registers the current task to receive a wakeup when we are awoken.
202+ fn record_waker (
203+ wakers_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ,
204+ waker_key : & mut usize ,
205+ cx : & mut Context < ' _ > ,
206+ ) {
207+ let wakers = match wakers_guard. as_mut ( ) {
208+ Some ( wakers) => wakers,
209+ None => return ,
210+ } ;
211+
212+ let new_waker = cx. waker ( ) ;
213+
214+ if * waker_key == NULL_WAKER_KEY {
215+ * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
216+ } else {
217+ match wakers[ * waker_key] {
218+ Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
219+ // Could use clone_from here, but Waker doesn't specialize it.
220+ ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
221+ }
222+ }
223+ debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
224+ }
217225
218- if * waker_key == NULL_WAKER_KEY {
219- * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
220- } else {
221- match wakers[ * waker_key] {
222- Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
223- // Could use clone_from here, but Waker doesn't specialize it.
224- ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
226+ /// Wakes all tasks that are registered to be woken.
227+ fn wake_all ( waker_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ) {
228+ if let Some ( wakers) = waker_guard. as_mut ( ) {
229+ for ( _key, opt_waker) in wakers {
230+ if let Some ( waker) = opt_waker. take ( ) {
231+ waker. wake ( ) ;
225232 }
226233 }
227- debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
228234 }
235+ }
229236
237+ impl < Fut > Inner < Fut >
238+ where
239+ Fut : Future ,
240+ Fut :: Output : Clone ,
241+ {
230242 /// Safety: callers must first ensure that `inner.state`
231243 /// is `COMPLETE`
232244 unsafe fn take_or_clone_output ( self : Arc < Self > ) -> Fut :: Output {
@@ -268,18 +280,22 @@ where
268280 return unsafe { Poll :: Ready ( inner. take_or_clone_output ( ) ) } ;
269281 }
270282
271- inner. record_waker ( & mut this. waker_key , cx) ;
283+ // Guard the state transition with mutex too
284+ let mut wakers_guard = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
285+ record_waker ( & mut wakers_guard, & mut this. waker_key , cx) ;
272286
273- match inner
287+ let prev = inner
274288 . notifier
275289 . state
276290 . compare_exchange ( IDLE , POLLING , SeqCst , SeqCst )
277- . unwrap_or_else ( |x| x)
278- {
291+ . unwrap_or_else ( |x| x) ;
292+ drop ( wakers_guard) ;
293+
294+ match prev {
279295 IDLE => {
280296 // Lock acquired, fall through
281297 }
282- POLLING => {
298+ POLLING | WOKEN_DURING_POLLING => {
283299 // Another task is currently polling, at this point we just want
284300 // to ensure that the waker for this task is registered
285301 this. inner = Some ( inner) ;
@@ -324,15 +340,21 @@ where
324340
325341 match poll_result {
326342 Poll :: Pending => {
327- if inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) . is_ok ( )
328- {
329- // Success
330- drop ( reset) ;
331- this. inner = Some ( inner) ;
332- return Poll :: Pending ;
333- } else {
334- unreachable ! ( )
343+ match inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) {
344+ Ok ( POLLING ) => { } // success
345+ Err ( WOKEN_DURING_POLLING ) => {
346+ // waker has been called inside future.poll, need to wake any new wakers registered
347+ let mut wakers = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
348+ wake_all ( & mut wakers) ;
349+ let prev = inner. notifier . state . swap ( IDLE , SeqCst ) ;
350+ assert_eq ! ( prev, WOKEN_DURING_POLLING ) ;
351+ drop ( wakers) ;
352+ }
353+ _ => unreachable ! ( ) ,
335354 }
355+ drop ( reset) ;
356+ this. inner = Some ( inner) ;
357+ return Poll :: Pending ;
336358 }
337359 Poll :: Ready ( output) => output,
338360 }
@@ -387,14 +409,9 @@ where
387409
388410impl ArcWake for Notifier {
389411 fn wake_by_ref ( arc_self : & Arc < Self > ) {
390- let wakers = & mut * arc_self. wakers . lock ( ) . unwrap ( ) ;
391- if let Some ( wakers) = wakers. as_mut ( ) {
392- for ( _key, opt_waker) in wakers {
393- if let Some ( waker) = opt_waker. take ( ) {
394- waker. wake ( ) ;
395- }
396- }
397- }
412+ let mut wakers = arc_self. wakers . lock ( ) . unwrap ( ) ;
413+ let _ = arc_self. state . compare_exchange ( POLLING , WOKEN_DURING_POLLING , SeqCst , SeqCst ) ;
414+ wake_all ( & mut wakers) ;
398415 }
399416}
400417
0 commit comments