diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 19ea178f6bf1f9..d7eead54192d55 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -238,7 +238,7 @@ private static void TransparentAwait(object o) private interface IRuntimeAsyncTaskOps { static abstract Action GetContinuationAction(T task); - static abstract Continuation GetContinuationState(T task); + static abstract Continuation MoveContinuationState(T task); static abstract void SetContinuationState(T task, Continuation value); static abstract bool SetCompleted(T task); static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx); @@ -297,9 +297,16 @@ void ITaskCompletionAction.Invoke(Task completingTask) private struct Ops : IRuntimeAsyncTaskOps> { public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; - public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!; + public static Continuation MoveContinuationState(RuntimeAsyncTask task) + { + Continuation continuation = (Continuation)task.m_stateObject!; + task.m_stateObject = null; + return continuation; + } + public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) { + Debug.Assert(task.m_stateObject == null); task.m_stateObject = value; } @@ -373,9 +380,16 @@ void ITaskCompletionAction.Invoke(Task completingTask) private struct Ops : IRuntimeAsyncTaskOps { public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; - public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!; + public static Continuation MoveContinuationState(RuntimeAsyncTask task) + { + Continuation continuation = (Continuation)task.m_stateObject!; + task.m_stateObject = null; + return continuation; + } + public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) { + Debug.Assert(task.m_stateObject == null); task.m_stateObject = value; } @@ -429,7 +443,7 @@ public static unsafe void DispatchContinuations(T task) where T : Task, DispatcherInfo dispatcherInfo; dispatcherInfo.Next = t_dispatcherInfo; - dispatcherInfo.NextContinuation = TOps.GetContinuationState(task); + dispatcherInfo.NextContinuation = TOps.MoveContinuationState(task); t_dispatcherInfo = &dispatcherInfo; while (true)