@@ -4,6 +4,7 @@ open System
44open System.Threading
55open System.Threading .Tasks
66open System.Runtime .CompilerServices
7+ open System.Runtime .ExceptionServices
78open FSharp.Core .CompilerServices .StateMachineHelpers
89open Microsoft.FSharp .Core .CompilerServices
910
@@ -15,17 +16,12 @@ type IAsync2Invocation<'t> =
1516
1617and Async2 < 't > =
1718 abstract StartImmediate: CancellationToken -> IAsync2Invocation < 't >
19+ abstract StartBound: CancellationToken -> TaskAwaiter < 't >
1820 abstract TailCall: CancellationToken * TaskCompletionSource < 't > voption -> unit
1921 abstract GetAwaiter: unit -> TaskAwaiter < 't >
2022
2123module Async2Implementation =
2224
23- open System.Runtime .ExceptionServices
24-
25- let failIfNot condition message =
26- if not condition then
27- failwith message
28-
2925 /// A structure that looks like an Awaiter
3026 type Awaiter < 'Awaiter , 'TResult
3127 when 'Awaiter :> ICriticalNotifyCompletion
@@ -54,22 +50,30 @@ module Async2Implementation =
5450 | Bounce of DynamicState
5551 | Immediate of DynamicState
5652
53+ module BindContext =
54+ let bindCount = new ThreadLocal< int>()
55+
56+ [<Literal>]
57+ let bindLimit = 100
58+
59+ let IncrementBindCount () =
60+ bindCount.Value <- bindCount.Value + 1
61+ bindCount.Value >= bindLimit
62+
63+ let Reset () = bindCount.Value <- 0
64+
5765 type Trampoline private () =
5866
5967 let ownerThreadId = Thread.CurrentThread.ManagedThreadId
6068
6169 static let holder = new ThreadLocal<_>( fun () -> Trampoline())
6270
63- [<Literal>]
64- static let bindLimit = 100
65-
66- let mutable bindCount = 0
67-
6871 let mutable pending : Action voption = ValueNone
6972 let mutable running = false
7073
7174 let start ( action : Action ) =
7275 try
76+ BindContext.Reset()
7377 running <- true
7478 action.Invoke()
7579
@@ -81,10 +85,10 @@ module Async2Implementation =
8185 running <- false
8286
8387 let set action =
84- failIfNot ( Thread.CurrentThread.ManagedThreadId = ownerThreadId) " Trampoline used from wrong thread"
85- failIfNot pending.IsNone " Trampoline used while already pending"
88+ assert ( Thread.CurrentThread.ManagedThreadId = ownerThreadId) // "Trampoline used from wrong thread"
89+ assert pending.IsNone // "Trampoline set while already pending"
8690
87- bindCount <- 0
91+ BindContext.Reset ()
8892
8993 if running then
9094 pending <- ValueSome action
@@ -97,14 +101,6 @@ module Async2Implementation =
97101
98102 member this.Ref : ICriticalNotifyCompletion ref = ref this
99103
100- member this.Set action = set action
101-
102- member this.Reset () = bindCount <- 0
103-
104- member _.IncrementBindCount () =
105- bindCount <- bindCount + 1
106- bindCount >= bindLimit
107-
108104 static member Current = holder.Value
109105
110106 module ExceptionCache =
@@ -148,6 +144,9 @@ module Async2Implementation =
148144 [<DefaultValue( false ) >]
149145 val mutable CancellationToken : CancellationToken
150146
147+ [<DefaultValue( false ) >]
148+ val mutable IsBound : bool
149+
151150 type Async2StateMachine < 'TOverall > = ResumableStateMachine< Async2Data< 'TOverall>>
152151 type IAsync2StateMachine < 'TOverall > = IResumableStateMachine< Async2Data< 'TOverall>>
153152 type Async2ResumptionFunc < 'TOverall > = ResumptionFunc< Async2Data< 'TOverall>>
@@ -160,11 +159,12 @@ module Async2Implementation =
160159 [<DefaultValue( false ) >]
161160 val mutable StateMachine : 'm
162161
163- member ts.Start ( ct , tc ) =
162+ member ts.Start ( ct , tailCallSource , isBound ) =
164163 let mutable copy = ts
165164 let mutable data = Async2Data()
166165 data.CancellationToken <- ct
167- data.TailCallSource <- tc
166+ data.TailCallSource <- tailCallSource
167+ data.IsBound <- isBound
168168 data.MethodBuilder <- AsyncTaskMethodBuilder< 't>. Create()
169169 copy.StateMachine.Data <- data
170170 copy.StateMachine.Data.MethodBuilder.Start(& copy.StateMachine)
@@ -177,29 +177,39 @@ module Async2Implementation =
177177 ts.StateMachine.Data.MethodBuilder.Task.GetAwaiter()
178178
179179 interface Async2< 't> with
180- member ts.StartImmediate ct = ts.Start( ct, ValueNone)
181- member ts.TailCall ( ct , tc ) = ts.Start( ct, tc) |> ignore
180+ member ts.StartImmediate ct = ts.Start( ct, ValueNone, false )
181+
182+ member ts.StartBound ct =
183+ ts.Start( ct, ValueNone, true ) .GetAwaiter()
184+
185+ member ts.TailCall ( ct , tc ) = ts.Start( ct, tc, true ) |> ignore
182186
183187 member ts.GetAwaiter () =
184- ts.Start( CancellationToken.None, ValueNone) .GetAwaiter()
188+ ts.Start( CancellationToken.None, ValueNone, true ) .GetAwaiter()
185189
186- type Async2Dynamic < 't , 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine < 't >>( getCopy : unit -> 'm ) =
187- member ts.GetCopy () =
188- Async2( StateMachine = getCopy () ) :> Async2<_>
190+ type Async2Dynamic < 't , 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine < 't >>( getCopy : bool -> 'm ) =
191+ member ts.GetCopy isBound =
192+ Async2( StateMachine = getCopy isBound ) :> Async2<_>
189193
190194 interface Async2< 't> with
191- member ts.StartImmediate ct = ts.GetCopy() .StartImmediate( ct)
192- member ts.TailCall ( ct , tc ) = ts.GetCopy() .TailCall( ct, tc) |> ignore
193- member ts.GetAwaiter () = ts.GetCopy() .GetAwaiter()
195+ member ts.StartImmediate ct = ts.GetCopy( false ) .StartImmediate( ct)
196+ member ts.StartBound ct = ts.GetCopy( true ) .StartBound( ct)
197+
198+ member ts.TailCall ( ct , tc ) =
199+ ts.GetCopy( true ) .TailCall( ct, tc) |> ignore
200+
201+ member ts.GetAwaiter () = ts.GetCopy( true ) .GetAwaiter()
194202
195203 [<AutoOpen>]
196204 module Async2Code =
197205 let inline filterCancellation ( [<InlineIfLambda>] catch : exn -> Async2Code < _ , _ >) ( exn : exn ) =
198206 Async2Code( fun sm ->
199- let ct = sm.Data.CancellationToken
200-
201207 match exn with
202- | :? OperationCanceledException as oce when ct.IsCancellationRequested || oce.CancellationToken = ct -> raise exn
208+ | :? OperationCanceledException as oce when
209+ sm.Data.CancellationToken.IsCancellationRequested
210+ || oce.CancellationToken = sm.Data.CancellationToken
211+ ->
212+ raise exn
203213 | _ -> ( catch exn) .Invoke(& sm))
204214
205215 let inline throwIfCancellationRequested ( code : Async2Code < _ , _ >) =
@@ -209,7 +219,7 @@ module Async2Implementation =
209219
210220 let inline yieldOnBindLimit () =
211221 Async2Code( fun sm ->
212- if Trampoline.Current .IncrementBindCount() then
222+ if BindContext .IncrementBindCount() then
213223 let __stack_yield_fin = ResumableCode.Yield() .Invoke(& sm)
214224
215225 if not __ stack_ yield_ fin then
@@ -249,7 +259,6 @@ module Async2Implementation =
249259 else
250260 bindDynamic (& sm, awaiter, continuation))
251261
252- [<NoEagerConstraintApplication>]
253262 let inline bindCancellable
254263 ( [<InlineIfLambda>] cancellable , [<InlineIfLambda>] continuation : 'U -> Async2Code < 'Data , 'T >)
255264 : Async2Code < 'Data , 'T > =
@@ -301,13 +310,15 @@ module Async2Implementation =
301310 let initialResumptionFunc = Async2ResumptionFunc< 'T>( fun sm -> code.Invoke & sm)
302311
303312 let maybeBounce state =
304- if Trampoline.Current .IncrementBindCount() then
313+ if BindContext .IncrementBindCount() then
305314 Bounce state
306315 else
307316 Immediate state
308317
309- let resumptionInfo () =
310- { new Async2ResumptionDynamicInfo< 'T>( initialResumptionFunc, ResumptionData = ( maybeBounce Running)) with
318+ let resumptionInfo isBound =
319+ let initialState = if isBound then maybeBounce Running else Immediate Running
320+
321+ { new Async2ResumptionDynamicInfo< 'T>( initialResumptionFunc, ResumptionData = initialState) with
311322 member info.MoveNext ( sm ) =
312323
313324 let getCurrent () =
@@ -352,7 +363,7 @@ module Async2Implementation =
352363 sm.Data.MethodBuilder.SetStateMachine( state)
353364 }
354365
355- Async2Dynamic<_, _>( fun () -> Async2StateMachine( ResumptionDynamicInfo = resumptionInfo () ))
366+ Async2Dynamic<_, _>( fun isBound -> Async2StateMachine( ResumptionDynamicInfo = resumptionInfo isBound ))
356367
357368 member inline _.Run ( code : Async2Code < 'T , 'T >) : Async2 < 'T > =
358369 if __ useResumableCode then
@@ -362,7 +373,7 @@ module Async2Implementation =
362373 __ resumeAt sm.ResumptionPoint
363374 let mutable error = ValueNone
364375
365- let __stack_go1 = yieldOnBindLimit() .Invoke(& sm)
376+ let __stack_go1 = not sm.Data.IsBound || yieldOnBindLimit() .Invoke(& sm)
366377
367378 if __ stack_ go1 then
368379 try
@@ -431,7 +442,7 @@ module HighPriority =
431442
432443 type Async2Builder with
433444 member inline this.Bind ( code : Async2 < 'U >, [<InlineIfLambda>] continuation ) : Async2Code < 'Data , 'T > =
434- bindCancellable (( fun ct -> code.StartImmediate ( ct ) .GetAwaiter ()) , continuation)
445+ bindCancellable ( code.StartBound , continuation)
435446
436447 member inline this.ReturnFrom ( code : Async2 < 'T >) : Async2Code < 'T , 'T > = this.Bind( code, this.Return)
437448
@@ -446,7 +457,7 @@ module HighPriority =
446457 | ValueSome tcs ->
447458 // We are already in a tail call chain.
448459 let __stack_ct = sm.Data.CancellationToken
449- Trampoline.Current.Set ( fun () -> code.TailCall(__ stack_ ct, ValueSome tcs) )
460+ code.TailCall(__ stack_ ct, ValueSome tcs)
450461 false // Return false to abandon this state machine and continue on the next one.
451462 )
452463
@@ -459,12 +470,8 @@ module Async2 =
459470
460471 let CheckAndThrowToken = AsyncLocal< CancellationToken>()
461472
462- let inline start ( code : Async2 < _ >) ct =
473+ let inline start ct ( code : Async2 < _ >) =
463474 CheckAndThrowToken.Value <- ct
464- // Only bound computations can participate in trampolining, otherwise we risk sync over async deadlocks.
465- // To prevent this, we reset the bind count here.
466- // This computation will not initially bounce, even if it is nested inside another async2 computation.
467- Trampoline.Current.Reset()
468475 code.StartImmediate ct
469476
470477 let run ct ( code : Async2 < 't >) =
@@ -473,25 +480,22 @@ module Async2 =
473480 isNull SynchronizationContext.Current
474481 && TaskScheduler.Current = TaskScheduler.Default
475482 then
476- start code ct |> _. Task .GetAwaiter() .GetResult()
483+ start ct code |> _. GetAwaiter() .GetResult()
477484 else
478- Task.Run< 't>( fun () -> start code ct |> _. Task) .GetAwaiter() .GetResult()
485+ Task.Run< 't>( fun () -> start ct code |> _. Task) .GetAwaiter() .GetResult()
479486
480487 let runWithoutCancellation code = run CancellationToken.None code
481488
482- let startAsTaskWithoutCancellation code = start code CancellationToken.None
483-
484- let startAsTask ct code = start code ct |> _. Task
485-
486- let queue ct code = Task.Run( fun () -> start code ct)
489+ let startAsTaskWithoutCancellation code =
490+ start CancellationToken.None code |> _. Task
487491
488492 let queueTask ct code =
489- Task.Run< 't>( fun () -> startAsTask ct code)
493+ Task.Run< 't>( fun () -> start ct code |> _. Task )
490494
491495 let toAsync ( code : Async2 < 't >) =
492496 async {
493497 let! ct = Async.CancellationToken
494- let task = startAsTask ct code
498+ let task = start ct code |> _. Task
495499 return ! Async.AwaitTask task
496500 }
497501
@@ -517,7 +521,7 @@ type Async2 =
517521
518522 static member StartAsTask ( computation : Async2 < _ >, ? cancellationToken : CancellationToken ) : Task < _ > =
519523 let ct = defaultArg cancellationToken CancellationToken.None
520- Async2.startAsTask ct computation
524+ Async2.start ct computation |> _. Task
521525
522526 static member RunImmediate ( computation : Async2 < 'T >, ? cancellationToken : CancellationToken ) : 'T =
523527 let ct = defaultArg cancellationToken CancellationToken.None
@@ -567,7 +571,7 @@ type Async2 =
567571 static member TryCancelled ( computation : Async2 < 'T >, compensation ) =
568572 async2 {
569573 let! ct = Async2.CancellationToken
570- let task = computation |> Async2.startAsTask ct
574+ let task = computation |> Async2.start ct |> _. Task
571575
572576 try
573577 return ! task
0 commit comments