Skip to content

Commit dfc6b56

Browse files
committed
better sort out bound vs immediate start
1 parent c070a10 commit dfc6b56

File tree

2 files changed

+79
-70
lines changed

2 files changed

+79
-70
lines changed

src/Compiler/Utilities/Async2.fs

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ open System
44
open System.Threading
55
open System.Threading.Tasks
66
open System.Runtime.CompilerServices
7+
open System.Runtime.ExceptionServices
78
open FSharp.Core.CompilerServices.StateMachineHelpers
89
open Microsoft.FSharp.Core.CompilerServices
910

@@ -15,17 +16,12 @@ type IAsync2Invocation<'t> =
1516

1617
and Async2<'t> =
1718
abstract StartImmediate: CancellationToken -> IAsync2Invocation<'t>
19+
abstract StartBound: CancellationToken -> TaskAwaiter<'t>
1820
abstract TailCall: CancellationToken * TaskCompletionSource<'t> voption -> unit
1921
abstract GetAwaiter: unit -> TaskAwaiter<'t>
2022

2123
module Async2Implementation =
2224

23-
open System.Runtime.ExceptionServices
24-
25-
let failIfNot condition message =
26-
if not condition then
27-
failwith message
28-
2925
/// A structure that looks like an Awaiter
3026
type Awaiter<'Awaiter, 'TResult
3127
when 'Awaiter :> ICriticalNotifyCompletion
@@ -54,22 +50,30 @@ module Async2Implementation =
5450
| Bounce of DynamicState
5551
| Immediate of DynamicState
5652

53+
module BindContext =
54+
let bindCount = new ThreadLocal<int>()
55+
56+
[<Literal>]
57+
let bindLimit = 100
58+
59+
let IncrementBindCount () =
60+
bindCount.Value <- bindCount.Value + 1
61+
bindCount.Value >= bindLimit
62+
63+
let Reset () = bindCount.Value <- 0
64+
5765
type Trampoline private () =
5866

5967
let ownerThreadId = Thread.CurrentThread.ManagedThreadId
6068

6169
static let holder = new ThreadLocal<_>(fun () -> Trampoline())
6270

63-
[<Literal>]
64-
static let bindLimit = 100
65-
66-
let mutable bindCount = 0
67-
6871
let mutable pending: Action voption = ValueNone
6972
let mutable running = false
7073

7174
let start (action: Action) =
7275
try
76+
BindContext.Reset()
7377
running <- true
7478
action.Invoke()
7579

@@ -81,10 +85,10 @@ module Async2Implementation =
8185
running <- false
8286

8387
let set action =
84-
failIfNot (Thread.CurrentThread.ManagedThreadId = ownerThreadId) "Trampoline used from wrong thread"
85-
failIfNot pending.IsNone "Trampoline used while already pending"
88+
assert (Thread.CurrentThread.ManagedThreadId = ownerThreadId) // "Trampoline used from wrong thread"
89+
assert pending.IsNone // "Trampoline set while already pending"
8690

87-
bindCount <- 0
91+
BindContext.Reset()
8892

8993
if running then
9094
pending <- ValueSome action
@@ -97,14 +101,6 @@ module Async2Implementation =
97101

98102
member this.Ref: ICriticalNotifyCompletion ref = ref this
99103

100-
member this.Set action = set action
101-
102-
member this.Reset() = bindCount <- 0
103-
104-
member _.IncrementBindCount() =
105-
bindCount <- bindCount + 1
106-
bindCount >= bindLimit
107-
108104
static member Current = holder.Value
109105

110106
module ExceptionCache =
@@ -148,6 +144,9 @@ module Async2Implementation =
148144
[<DefaultValue(false)>]
149145
val mutable CancellationToken: CancellationToken
150146

147+
[<DefaultValue(false)>]
148+
val mutable IsBound: bool
149+
151150
type Async2StateMachine<'TOverall> = ResumableStateMachine<Async2Data<'TOverall>>
152151
type IAsync2StateMachine<'TOverall> = IResumableStateMachine<Async2Data<'TOverall>>
153152
type Async2ResumptionFunc<'TOverall> = ResumptionFunc<Async2Data<'TOverall>>
@@ -160,11 +159,12 @@ module Async2Implementation =
160159
[<DefaultValue(false)>]
161160
val mutable StateMachine: 'm
162161

163-
member ts.Start(ct, tc) =
162+
member ts.Start(ct, tailCallSource, isBound) =
164163
let mutable copy = ts
165164
let mutable data = Async2Data()
166165
data.CancellationToken <- ct
167-
data.TailCallSource <- tc
166+
data.TailCallSource <- tailCallSource
167+
data.IsBound <- isBound
168168
data.MethodBuilder <- AsyncTaskMethodBuilder<'t>.Create()
169169
copy.StateMachine.Data <- data
170170
copy.StateMachine.Data.MethodBuilder.Start(&copy.StateMachine)
@@ -177,29 +177,39 @@ module Async2Implementation =
177177
ts.StateMachine.Data.MethodBuilder.Task.GetAwaiter()
178178

179179
interface Async2<'t> with
180-
member ts.StartImmediate ct = ts.Start(ct, ValueNone)
181-
member ts.TailCall(ct, tc) = ts.Start(ct, tc) |> ignore
180+
member ts.StartImmediate ct = ts.Start(ct, ValueNone, false)
181+
182+
member ts.StartBound ct =
183+
ts.Start(ct, ValueNone, true).GetAwaiter()
184+
185+
member ts.TailCall(ct, tc) = ts.Start(ct, tc, true) |> ignore
182186

183187
member ts.GetAwaiter() =
184-
ts.Start(CancellationToken.None, ValueNone).GetAwaiter()
188+
ts.Start(CancellationToken.None, ValueNone, true).GetAwaiter()
185189

186-
type Async2Dynamic<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>(getCopy: unit -> 'm) =
187-
member ts.GetCopy() =
188-
Async2(StateMachine = getCopy ()) :> Async2<_>
190+
type Async2Dynamic<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>(getCopy: bool -> 'm) =
191+
member ts.GetCopy isBound =
192+
Async2(StateMachine = getCopy isBound) :> Async2<_>
189193

190194
interface Async2<'t> with
191-
member ts.StartImmediate ct = ts.GetCopy().StartImmediate(ct)
192-
member ts.TailCall(ct, tc) = ts.GetCopy().TailCall(ct, tc) |> ignore
193-
member ts.GetAwaiter() = ts.GetCopy().GetAwaiter()
195+
member ts.StartImmediate ct = ts.GetCopy(false).StartImmediate(ct)
196+
member ts.StartBound ct = ts.GetCopy(true).StartBound(ct)
197+
198+
member ts.TailCall(ct, tc) =
199+
ts.GetCopy(true).TailCall(ct, tc) |> ignore
200+
201+
member ts.GetAwaiter() = ts.GetCopy(true).GetAwaiter()
194202

195203
[<AutoOpen>]
196204
module Async2Code =
197205
let inline filterCancellation ([<InlineIfLambda>] catch: exn -> Async2Code<_, _>) (exn: exn) =
198206
Async2Code(fun sm ->
199-
let ct = sm.Data.CancellationToken
200-
201207
match exn with
202-
| :? OperationCanceledException as oce when ct.IsCancellationRequested || oce.CancellationToken = ct -> raise exn
208+
| :? OperationCanceledException as oce when
209+
sm.Data.CancellationToken.IsCancellationRequested
210+
|| oce.CancellationToken = sm.Data.CancellationToken
211+
->
212+
raise exn
203213
| _ -> (catch exn).Invoke(&sm))
204214

205215
let inline throwIfCancellationRequested (code: Async2Code<_, _>) =
@@ -209,7 +219,7 @@ module Async2Implementation =
209219

210220
let inline yieldOnBindLimit () =
211221
Async2Code(fun sm ->
212-
if Trampoline.Current.IncrementBindCount() then
222+
if BindContext.IncrementBindCount() then
213223
let __stack_yield_fin = ResumableCode.Yield().Invoke(&sm)
214224

215225
if not __stack_yield_fin then
@@ -249,7 +259,6 @@ module Async2Implementation =
249259
else
250260
bindDynamic (&sm, awaiter, continuation))
251261

252-
[<NoEagerConstraintApplication>]
253262
let inline bindCancellable
254263
([<InlineIfLambda>] cancellable, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>)
255264
: Async2Code<'Data, 'T> =
@@ -301,13 +310,15 @@ module Async2Implementation =
301310
let initialResumptionFunc = Async2ResumptionFunc<'T>(fun sm -> code.Invoke &sm)
302311

303312
let maybeBounce state =
304-
if Trampoline.Current.IncrementBindCount() then
313+
if BindContext.IncrementBindCount() then
305314
Bounce state
306315
else
307316
Immediate state
308317

309-
let resumptionInfo () =
310-
{ new Async2ResumptionDynamicInfo<'T>(initialResumptionFunc, ResumptionData = (maybeBounce Running)) with
318+
let resumptionInfo isBound =
319+
let initialState = if isBound then maybeBounce Running else Immediate Running
320+
321+
{ new Async2ResumptionDynamicInfo<'T>(initialResumptionFunc, ResumptionData = initialState) with
311322
member info.MoveNext(sm) =
312323

313324
let getCurrent () =
@@ -352,7 +363,7 @@ module Async2Implementation =
352363
sm.Data.MethodBuilder.SetStateMachine(state)
353364
}
354365

355-
Async2Dynamic<_, _>(fun () -> Async2StateMachine(ResumptionDynamicInfo = resumptionInfo ()))
366+
Async2Dynamic<_, _>(fun isBound -> Async2StateMachine(ResumptionDynamicInfo = resumptionInfo isBound))
356367

357368
member inline _.Run(code: Async2Code<'T, 'T>) : Async2<'T> =
358369
if __useResumableCode then
@@ -362,7 +373,7 @@ module Async2Implementation =
362373
__resumeAt sm.ResumptionPoint
363374
let mutable error = ValueNone
364375

365-
let __stack_go1 = yieldOnBindLimit().Invoke(&sm)
376+
let __stack_go1 = not sm.Data.IsBound || yieldOnBindLimit().Invoke(&sm)
366377

367378
if __stack_go1 then
368379
try
@@ -431,7 +442,7 @@ module HighPriority =
431442

432443
type Async2Builder with
433444
member inline this.Bind(code: Async2<'U>, [<InlineIfLambda>] continuation) : Async2Code<'Data, 'T> =
434-
bindCancellable ((fun ct -> code.StartImmediate(ct).GetAwaiter()), continuation)
445+
bindCancellable (code.StartBound, continuation)
435446

436447
member inline this.ReturnFrom(code: Async2<'T>) : Async2Code<'T, 'T> = this.Bind(code, this.Return)
437448

@@ -446,7 +457,7 @@ module HighPriority =
446457
| ValueSome tcs ->
447458
// We are already in a tail call chain.
448459
let __stack_ct = sm.Data.CancellationToken
449-
Trampoline.Current.Set(fun () -> code.TailCall(__stack_ct, ValueSome tcs))
460+
code.TailCall(__stack_ct, ValueSome tcs)
450461
false // Return false to abandon this state machine and continue on the next one.
451462
)
452463

@@ -459,12 +470,8 @@ module Async2 =
459470

460471
let CheckAndThrowToken = AsyncLocal<CancellationToken>()
461472

462-
let inline start (code: Async2<_>) ct =
473+
let inline start ct (code: Async2<_>) =
463474
CheckAndThrowToken.Value <- ct
464-
// Only bound computations can participate in trampolining, otherwise we risk sync over async deadlocks.
465-
// To prevent this, we reset the bind count here.
466-
// This computation will not initially bounce, even if it is nested inside another async2 computation.
467-
Trampoline.Current.Reset()
468475
code.StartImmediate ct
469476

470477
let run ct (code: Async2<'t>) =
@@ -473,25 +480,22 @@ module Async2 =
473480
isNull SynchronizationContext.Current
474481
&& TaskScheduler.Current = TaskScheduler.Default
475482
then
476-
start code ct |> _.Task.GetAwaiter().GetResult()
483+
start ct code |> _.GetAwaiter().GetResult()
477484
else
478-
Task.Run<'t>(fun () -> start code ct |> _.Task).GetAwaiter().GetResult()
485+
Task.Run<'t>(fun () -> start ct code |> _.Task).GetAwaiter().GetResult()
479486

480487
let runWithoutCancellation code = run CancellationToken.None code
481488

482-
let startAsTaskWithoutCancellation code = start code CancellationToken.None
483-
484-
let startAsTask ct code = start code ct |> _.Task
485-
486-
let queue ct code = Task.Run(fun () -> start code ct)
489+
let startAsTaskWithoutCancellation code =
490+
start CancellationToken.None code |> _.Task
487491

488492
let queueTask ct code =
489-
Task.Run<'t>(fun () -> startAsTask ct code)
493+
Task.Run<'t>(fun () -> start ct code |> _.Task)
490494

491495
let toAsync (code: Async2<'t>) =
492496
async {
493497
let! ct = Async.CancellationToken
494-
let task = startAsTask ct code
498+
let task = start ct code |> _.Task
495499
return! Async.AwaitTask task
496500
}
497501

@@ -517,7 +521,7 @@ type Async2 =
517521

518522
static member StartAsTask(computation: Async2<_>, ?cancellationToken: CancellationToken) : Task<_> =
519523
let ct = defaultArg cancellationToken CancellationToken.None
520-
Async2.startAsTask ct computation
524+
Async2.start ct computation |> _.Task
521525

522526
static member RunImmediate(computation: Async2<'T>, ?cancellationToken: CancellationToken) : 'T =
523527
let ct = defaultArg cancellationToken CancellationToken.None
@@ -567,7 +571,7 @@ type Async2 =
567571
static member TryCancelled(computation: Async2<'T>, compensation) =
568572
async2 {
569573
let! ct = Async2.CancellationToken
570-
let task = computation |> Async2.startAsTask ct
574+
let task = computation |> Async2.start ct |> _.Task
571575

572576
try
573577
return! task

0 commit comments

Comments
 (0)