@@ -54,7 +54,9 @@ module Control.Monad.IOSim.Internal
5454import Prelude hiding (read )
5555
5656import Data.Dynamic
57- import Data.Foldable (toList , traverse_ )
57+ import Data.Foldable (toList , traverse_ , foldlM )
58+ import Deque.Strict (Deque )
59+ import qualified Deque.Strict as Deque
5860import qualified Data.List as List
5961import qualified Data.List.Trace as Trace
6062import Data.Map.Strict (Map )
@@ -65,16 +67,13 @@ import qualified Data.OrdPSQ as PSQ
6567import Data.Set (Set )
6668import qualified Data.Set as Set
6769import Data.Time (UTCTime (.. ), fromGregorian )
68- import Deque.Strict (Deque )
69- import qualified Deque.Strict as Deque
7070
7171import GHC.Exts (fromList )
7272import GHC.Conc (ThreadStatus (.. ), BlockReason (.. ))
7373
74- import Control.Exception (NonTermination (.. ), assert , throw )
75- import Control.Monad (join )
76-
77- import Control.Monad (when )
74+ import Control.Exception
75+ (NonTermination (.. ), assert , throw , AsyncException (.. ))
76+ import Control.Monad (join , when )
7877import Control.Monad.ST.Lazy
7978import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST , unsafeInterleaveST )
8079import Data.STRef.Lazy
@@ -126,6 +125,7 @@ data TimerCompletionInfo s =
126125 Timer ! (TVar s TimeoutState )
127126 | TimerRegisterDelay ! (TVar s Bool )
128127 | TimerThreadDelay ! ThreadId
128+ | TimerTimeout ! ThreadId ! TimeoutId ! (STRef s IsLocked )
129129
130130-- | Internal state.
131131--
@@ -138,7 +138,7 @@ data SimState s a = SimState {
138138 finished :: ! (Map ThreadId FinishedReason ),
139139 -- | current time
140140 curTime :: ! Time ,
141- -- | ordered list of timers
141+ -- | ordered list of timers and timeouts
142142 timers :: ! (OrdPSQ TimeoutId Time (TimerCompletionInfo s )),
143143 -- | list of clocks
144144 clocks :: ! (Map ClockId UTCTime ),
@@ -235,8 +235,53 @@ schedule !thread@Thread{
235235 let thread' = thread { threadControl = ThreadControl (k x) ctl' }
236236 schedule thread' simstate
237237
238+ TimeoutFrame tmid isLockedRef k ctl' -> do
239+ -- There is a possible race between timeout action and the timeout expiration.
240+ -- We use a lock to solve the race.
241+ --
242+ -- The lock starts 'NotLocked' and when the timeout fires the lock is
243+ -- locked and asynchronously an assassin thread is coming to interrupt
244+ -- it. If the lock is locked when the timeout is fired then nothing
245+ -- happens.
246+ --
247+ -- Knowing this, if we reached this point in the code and the lock is
248+ -- 'Locked', then it means that this thread still hasn't received the
249+ -- 'TimeoutException', so we need to kill the thread that is responsible
250+ -- for doing that (the assassin thread, we need to defend ourselves!)
251+ -- and run our continuation successfully and peacefully. We will do that
252+ -- by uninterruptibly-masking ourselves so we can not receive any
253+ -- exception and kill the assassin thread behind its back.
254+ -- If the lock is 'NotLocked' then it means we can just acquire it and
255+ -- carry on with the success case.
256+ locked <- readSTRef isLockedRef
257+ case locked of
258+ Locked etid -> do
259+ let -- Kill the assassin throwing thread and carry on the
260+ -- continuation
261+ thread' =
262+ thread { threadControl =
263+ ThreadControl (ThrowTo (toException ThreadKilled )
264+ etid
265+ (k (Just x)))
266+ ctl'
267+ , threadMasking = MaskedUninterruptible
268+ }
269+ schedule thread' simstate
270+
271+ NotLocked -> do
272+ -- Acquire lock
273+ writeSTRef isLockedRef (Locked tid)
274+
275+ -- Remove the timer from the queue
276+ let timers' = PSQ. delete tmid timers
277+ -- Run the continuation
278+ thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' }
279+
280+ schedule thread' simstate { timers = timers'
281+ }
238282 Throw thrower e -> {-# SCC "schedule.Throw" #-}
239283 case unwindControlStack e thread of
284+ -- Found a CatchFrame
240285 Right thread'@ Thread { threadMasking = maskst' } -> do
241286 -- We found a suitable exception handler, continue with that
242287 trace <- schedule thread' simstate
@@ -360,6 +405,23 @@ schedule !thread@Thread{
360405 , nextTmid = succ nextTmid }
361406 return (SimTrace time tid tlbl (EventTimerCreated nextTmid nextVid expiry) trace)
362407
408+ -- This case is guarded by checks in 'timeout' itself.
409+ StartTimeout d _ _ | d <= 0 ->
410+ error " schedule: StartTimeout: Impossible happened"
411+
412+ StartTimeout d action' k ->
413+ {-# SCC "schedule.StartTimeout" #-} do
414+ isLockedRef <- newSTRef NotLocked
415+ let ! expiry = d `addTime` time
416+ ! timers' = PSQ. insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers
417+ ! thread' = thread { threadControl =
418+ ThreadControl action'
419+ (TimeoutFrame nextTmid isLockedRef k ctl)
420+ }
421+ ! trace <- deschedule Yield thread' simstate { timers = timers'
422+ , nextTmid = succ nextTmid }
423+ return (SimTrace time tid tlbl (EventTimeoutCreated nextTmid tid expiry) trace)
424+
363425 RegisterDelay d k | d < 0 ->
364426 {-# SCC "schedule.NewRegisterDelay" #-} do
365427 ! tvar <- execNewTVar nextVid
@@ -404,7 +466,6 @@ schedule !thread@Thread{
404466 , nextTmid = succ nextTmid }
405467 return (SimTrace time tid tlbl (EventThreadDelay expiry) trace)
406468
407-
408469 -- we do not follow `GHC.Event` behaviour here; updating a timer to the past
409470 -- effectively cancels it.
410471 UpdateTimeout (Timeout _tvar tmid) d k | d < 0 ->
@@ -777,8 +838,23 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
777838 wakeup = wakeupThreadDelay ++ wakeupSTM
778839 (_, ! simstate') = unblockThreads wakeup simstate
779840
780- ! trace <- reschedule simstate' { curTime = time'
781- , timers = timers' }
841+ -- For each 'timeout' action where the timeout has fired, start a
842+ -- new thread to execute throwTo to interrupt the action.
843+ ! timeoutExpired = [ (tid, tmid, isLockedRef)
844+ | TimerTimeout tid tmid isLockedRef <- fired ]
845+
846+ -- Get the isLockedRef values
847+ ! timeoutExpired' <- traverse (\ (tid, tmid, isLockedRef) -> do
848+ locked <- readSTRef isLockedRef
849+ return (tid, tmid, isLockedRef, locked)
850+ )
851+ timeoutExpired
852+
853+ ! simstate'' <- forkTimeoutInterruptThreads timeoutExpired' simstate'
854+
855+ ! trace <- reschedule simstate'' { curTime = time'
856+ , timers = timers' }
857+
782858 return $
783859 traceMany ([ ( time', ThreadId [- 1 ], Just " timer"
784860 , EventTimerFired tmid)
@@ -792,7 +868,13 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
792868 , let Just vids = Set. toList <$> Map. lookup tid' wokeby ]
793869 ++ [ ( time', tid, Just " thread delay timer"
794870 , EventThreadDelayFired )
795- | tid <- wakeupThreadDelay ])
871+ | tid <- wakeupThreadDelay ]
872+ ++ [ ( time', tid, Just " timeout timer"
873+ , EventTimeoutFired tmid)
874+ | (tid, tmid, _, _) <- timeoutExpired' ]
875+ ++ [ ( time', tid, Just " thread forked"
876+ , EventThreadForked tid)
877+ | (tid, _, _, _) <- timeoutExpired' ])
796878 trace
797879 where
798880 timeoutSTMAction (Timer var) = do
@@ -804,7 +886,8 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
804886 timeoutSTMAction (TimerRegisterDelay var) = writeTVar var True
805887 -- Note that 'threadDelay' is not handled via STM style wakeup, but rather
806888 -- it's handled directly above with 'wakeupThreadDelay' and 'unblockThreads'
807- timeoutSTMAction (TimerThreadDelay _) = return ()
889+ timeoutSTMAction TimerThreadDelay {} = return ()
890+ timeoutSTMAction TimerTimeout {} = return ()
808891
809892unblockThreads :: [ThreadId ] -> SimState s a -> ([ThreadId ], SimState s a )
810893unblockThreads ! wakeup ! simstate@ SimState {runqueue, threads} =
@@ -825,7 +908,76 @@ unblockThreads !wakeup !simstate@SimState {runqueue, threads} =
825908 -- and in which case we mark them as now running
826909 ! threads' = List. foldl'
827910 (flip (Map. adjust (\ t -> t { threadBlocked = False })))
828- threads unblocked
911+ threads
912+ unblocked
913+
914+ -- | This function receives a list of TimerTimeout values that represent threads
915+ -- for which the timeout expired and kills the running thread if needed.
916+ --
917+ -- This function is responsible for the second part of the race condition issue
918+ -- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is
919+ -- where the assassin threads are launched. So, as explained previously, at this
920+ -- point in code, the timeout expired so we need to interrupt the running
921+ -- thread. If the running thread finished at the same time the timeout expired
922+ -- we have a race condition. To deal with this race condition what we do is
923+ -- look at the lock value. If it is 'Locked' this means that the running thread
924+ -- already finished (or won the race) so we can safely do nothing. Otherwise, if
925+ -- the lock value is 'NotLocked' we need to acquire the lock and launch an
926+ -- assassin thread that is going to interrupt the running one. Note that we
927+ -- should run this interrupting thread in an unmasked state since it might
928+ -- receive a 'ThreadKilled' exception.
929+ --
930+ forkTimeoutInterruptThreads :: [(ThreadId , TimeoutId , STRef s IsLocked , IsLocked )]
931+ -> SimState s a
932+ -> ST s (SimState s a )
933+ forkTimeoutInterruptThreads timeoutExpired simState@ SimState {threads} =
934+ foldlM (\ st@ SimState { runqueue = runqueue,
935+ threads = threads'
936+ }
937+ (t, isLockedRef)
938+ -> do
939+ let tid' = threadId t
940+ threads'' = Map. insert tid' t threads'
941+ runqueue' = Deque. snoc tid' runqueue
942+
943+ writeSTRef isLockedRef (Locked tid')
944+
945+ return st { runqueue = runqueue',
946+ threads = threads''
947+ })
948+ simState
949+ throwToThread
950+
951+ where
952+ -- can only throw exception if the thread exists and if the mutually
953+ -- exclusive lock exists and is still 'NotLocked'
954+ toThrow = [ (tid, tmid, ref, t)
955+ | (tid, tmid, ref, locked) <- timeoutExpired
956+ , Just t <- [Map. lookup tid threads]
957+ , NotLocked <- [locked]
958+ ]
959+ -- we launch a thread responsible for throwing an AsyncCancelled exception
960+ -- to the thread which timeout expired
961+ throwToThread =
962+ [ let nextId = threadNextTId t
963+ tid' = childThreadId tid nextId
964+ in ( Thread { threadId = tid',
965+ threadControl =
966+ ThreadControl
967+ (ThrowTo (toException (TimeoutException tmid))
968+ tid
969+ (Return () ))
970+ ForkFrame ,
971+ threadBlocked = False ,
972+ threadMasking = Unmasked ,
973+ threadThrowTo = [] ,
974+ threadClockId = threadClockId t,
975+ threadLabel = Just " timeout-forked-thread" ,
976+ threadNextTId = 1
977+ }
978+ , ref )
979+ | (tid, tmid, ref, t) <- toThrow
980+ ]
829981
830982
831983-- | Iterate through the control stack to find an enclosing exception handler
@@ -843,7 +995,8 @@ unwindControlStack e thread =
843995 ThreadControl _ ctl -> unwind (threadMasking thread) ctl
844996 where
845997 unwind :: forall s' c . MaskingState
846- -> ControlStack s' c a -> Either Bool (Thread s' a )
998+ -> ControlStack s' c a
999+ -> Either Bool (Thread s' a )
8471000 unwind _ MainFrame = Left True
8481001 unwind _ ForkFrame = Left False
8491002 unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl
@@ -855,12 +1008,28 @@ unwindControlStack e thread =
8551008
8561009 -- Ok! We will be able to continue the thread with the handler
8571010 -- followed by the continuation after the catch
858- Just e' -> Right thread {
859- -- As per async exception rules, the handler is run masked
1011+ Just e' -> Right ( thread {
1012+ -- As per async exception rules, the handler is run
1013+ -- masked
8601014 threadControl = ThreadControl (handler e')
8611015 (MaskFrame k maskst ctl),
8621016 threadMasking = atLeastInterruptibleMask maskst
8631017 }
1018+ )
1019+
1020+ -- Either Timeout fired or the action threw an exception.
1021+ -- - If Timeout fired, then it was possibly during this thread's execution
1022+ -- so we need to run the continuation with a Nothing value.
1023+ -- - If the timeout action threw an exception we need to keep unwinding the
1024+ -- control stack looking for a handler to this exception.
1025+ unwind maskst (TimeoutFrame tmid _ k ctl) =
1026+ case fromException e of
1027+ -- Exception came from timeout expiring
1028+ Just (TimeoutException tmid') ->
1029+ assert (tmid == tmid')
1030+ Right thread { threadControl = ThreadControl (k Nothing ) ctl }
1031+ -- Exception came from a different exception
1032+ _ -> unwind maskst ctl
8641033
8651034 atLeastInterruptibleMask :: MaskingState -> MaskingState
8661035 atLeastInterruptibleMask Unmasked = MaskedInterruptible
0 commit comments