Skip to content

Commit ef15724

Browse files
committed
sort out overloads again
1 parent 9c7ecb2 commit ef15724

File tree

1 file changed

+110
-119
lines changed

1 file changed

+110
-119
lines changed

src/Compiler/Utilities/Async2.fs

Lines changed: 110 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ open System
44
open System.Threading
55
open System.Threading.Tasks
66
open System.Runtime.CompilerServices
7+
open FSharp.Core.CompilerServices.StateMachineHelpers
8+
open Microsoft.FSharp.Core.CompilerServices
79

810
#nowarn 3513
911

1012
type IAsync2Invocation<'t> =
1113
abstract Task: Task<'t>
14+
abstract GetAwaiter: unit -> TaskAwaiter<'t>
1215

1316
and Async2<'t> =
1417
abstract StartImmediate: CancellationToken -> IAsync2Invocation<'t>
@@ -17,9 +20,6 @@ and Async2<'t> =
1720

1821
module Async2Implementation =
1922

20-
open FSharp.Core.CompilerServices.StateMachineHelpers
21-
22-
open Microsoft.FSharp.Core.CompilerServices
2323
open System.Runtime.ExceptionServices
2424

2525
let failIfNot condition message =
@@ -35,15 +35,16 @@ module Async2Implementation =
3535
type Awaitable<'Awaitable, 'Awaiter, 'TResult when 'Awaitable: (member GetAwaiter: unit -> Awaiter<'Awaiter, 'TResult>)> = 'Awaitable
3636

3737
module Awaiter =
38-
let inline isCompleted (awaiter: ^Awaiter) : bool when ^Awaiter: (member get_IsCompleted: unit -> bool) = awaiter.get_IsCompleted ()
38+
let inline isCompleted (awaiter: Awaiter<_, _>) = awaiter.get_IsCompleted ()
3939

40-
let inline getResult (awaiter: ^Awaiter) : ^TResult when ^Awaiter: (member GetResult: unit -> ^TResult) = awaiter.GetResult()
40+
let inline getResult (awaiter: Awaiter<_, _>) = awaiter.GetResult()
4141

42-
let inline onCompleted (awaiter: ^Awaiter) (continuation: Action) : unit when ^Awaiter :> INotifyCompletion =
43-
awaiter.OnCompleted continuation
42+
let inline onCompleted (awaiter: Awaiter<_, _>) continuation = awaiter.OnCompleted continuation
4443

45-
let inline unsafeOnCompleted (awaiter: ^Awaiter) (continuation: Action) : unit when ^Awaiter :> ICriticalNotifyCompletion =
46-
awaiter.UnsafeOnCompleted continuation
44+
let inline unsafeOnCompleted (awaiter: Awaiter<_, _>) continuation = awaiter.UnsafeOnCompleted continuation
45+
46+
module Awaitable =
47+
let inline getAwaiter (awaitable: Awaitable<_, _, _>) = awaitable.GetAwaiter()
4748

4849
type DynamicState =
4950
| Running
@@ -91,8 +92,8 @@ module Async2Implementation =
9192
start action
9293

9394
interface ICriticalNotifyCompletion with
94-
member _.OnCompleted(continuation) = set continuation
95-
member _.UnsafeOnCompleted(continuation) = set continuation
95+
member _.OnCompleted continuation = set continuation
96+
member _.UnsafeOnCompleted continuation = set continuation
9697

9798
member this.Ref: ICriticalNotifyCompletion ref = ref this
9899

@@ -155,7 +156,7 @@ module Async2Implementation =
155156
type Async2Code<'TOverall, 'T> = ResumableCode<Async2Data<'TOverall>, 'T>
156157

157158
[<Struct; NoComparison>]
158-
type Async2Impl<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>> =
159+
type Async2<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>> =
159160
[<DefaultValue(false)>]
160161
val mutable StateMachine: 'm
161162

@@ -172,26 +173,24 @@ module Async2Implementation =
172173
interface IAsync2Invocation<'t> with
173174
member ts.Task = ts.StateMachine.Data.MethodBuilder.Task
174175

176+
member ts.GetAwaiter() =
177+
ts.StateMachine.Data.MethodBuilder.Task.GetAwaiter()
178+
175179
interface Async2<'t> with
176180
member ts.StartImmediate ct = ts.Start(ct, ValueNone)
177181
member ts.TailCall(ct, tc) = ts.Start(ct, tc) |> ignore
178182

179183
member ts.GetAwaiter() =
180-
ts.Start(CancellationToken.None, ValueNone).Task.GetAwaiter()
181-
182-
[<NoComparison>]
183-
type Async2ImplDynamic<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>(getCopy: unit -> 'm) =
184+
ts.Start(CancellationToken.None, ValueNone).GetAwaiter()
184185

185-
member ts.Start(ct, tc) =
186-
let mutable copy = Async2Impl(StateMachine = getCopy ())
187-
copy.Start(ct, tc)
186+
type Async2Dynamic<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>(getCopy: unit -> 'm) =
187+
member ts.GetCopy() =
188+
Async2(StateMachine = getCopy ()) :> Async2<_>
188189

189190
interface Async2<'t> with
190-
member ts.StartImmediate ct = ts.Start(ct, ValueNone)
191-
member ts.TailCall(ct, tc) = ts.Start(ct, tc) |> ignore
192-
193-
member ts.GetAwaiter() =
194-
ts.Start(CancellationToken.None, ValueNone).Task.GetAwaiter()
191+
member ts.StartImmediate ct = ts.GetCopy().StartImmediate(ct)
192+
member ts.TailCall(ct, tc) = ts.GetCopy().TailCall(ct, tc) |> ignore
193+
member ts.GetAwaiter() = ts.GetCopy().GetAwaiter()
195194

196195
[<AutoOpen>]
197196
module Async2Code =
@@ -209,7 +208,7 @@ module Async2Implementation =
209208
code.Invoke(&sm))
210209

211210
let inline yieldOnBindLimit () =
212-
Async2Code<_, _>(fun sm ->
211+
Async2Code(fun sm ->
213212
if Trampoline.Current.IncrementBindCount() then
214213
let __stack_yield_fin = ResumableCode.Yield().Invoke(&sm)
215214

@@ -220,7 +219,41 @@ module Async2Implementation =
220219
else
221220
true)
222221

223-
type CancellableAwaiter<'t, 'a when Awaiter<'a, 't>> = CancellationToken -> 'a
222+
let inline bindDynamic (sm: byref<Async2StateMachine<_>>, awaiter, [<InlineIfLambda>] continuation: _ -> Async2Code<_, _>) =
223+
if Awaiter.isCompleted awaiter then
224+
(Awaiter.getResult awaiter |> continuation).Invoke(&sm)
225+
else
226+
let resumptionFunc =
227+
Async2ResumptionFunc(fun sm ->
228+
let result = ExceptionCache.GetResultOrThrow awaiter
229+
(continuation result).Invoke(&sm))
230+
231+
sm.ResumptionDynamicInfo.ResumptionFunc <- resumptionFunc
232+
sm.ResumptionDynamicInfo.ResumptionData <- Awaiting awaiter
233+
false
234+
235+
let inline bindAwaiter (awaiter, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>) : Async2Code<'Data, 'T> =
236+
Async2Code(fun sm ->
237+
if __useResumableCode then
238+
if Awaiter.isCompleted awaiter then
239+
continuation(ExceptionCache.GetResultOrThrow awaiter).Invoke(&sm)
240+
else
241+
let __stack_yield_fin = ResumableCode.Yield().Invoke(&sm)
242+
243+
if __stack_yield_fin then
244+
continuation(ExceptionCache.GetResultOrThrow awaiter).Invoke(&sm)
245+
else
246+
let mutable __stack_awaiter = awaiter
247+
sm.Data.MethodBuilder.AwaitUnsafeOnCompleted(&__stack_awaiter, &sm)
248+
false
249+
else
250+
bindDynamic (&sm, awaiter, continuation))
251+
252+
[<NoEagerConstraintApplication>]
253+
let inline bindCancellable
254+
([<InlineIfLambda>] cancellable, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>)
255+
: Async2Code<'Data, 'T> =
256+
Async2Code<'Data, 'T>(fun sm -> bindAwaiter(cancellable sm.Data.CancellationToken, continuation).Invoke(&sm))
224257

225258
type Async2Builder() =
226259

@@ -264,81 +297,6 @@ module Async2Implementation =
264297
member inline _.For(sequence: seq<'T>, [<InlineIfLambda>] body: 'T -> Async2Code<'TOverall, unit>) : Async2Code<'TOverall, unit> =
265298
ResumableCode.For(sequence, fun x -> body x |> throwIfCancellationRequested)
266299

267-
static member inline BindDynamic
268-
(sm: byref<Async2StateMachine<_>>, awaiter, [<InlineIfLambda>] continuation: _ -> Async2Code<_, _>)
269-
=
270-
if Awaiter.isCompleted awaiter then
271-
(Awaiter.getResult awaiter |> continuation).Invoke(&sm)
272-
else
273-
let resumptionFunc =
274-
Async2ResumptionFunc(fun sm ->
275-
let result = ExceptionCache.GetResultOrThrow awaiter
276-
(continuation result).Invoke(&sm))
277-
278-
sm.ResumptionDynamicInfo.ResumptionFunc <- resumptionFunc
279-
sm.ResumptionDynamicInfo.ResumptionData <- Awaiting awaiter
280-
false
281-
282-
member inline _.BindAwaiter
283-
(awaiter: Awaiter<_, _>, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>)
284-
: Async2Code<'Data, 'T> =
285-
Async2Code(fun sm ->
286-
if __useResumableCode then
287-
if Awaiter.isCompleted awaiter then
288-
continuation(ExceptionCache.GetResultOrThrow awaiter).Invoke(&sm)
289-
else
290-
let __stack_yield_fin = ResumableCode.Yield().Invoke(&sm)
291-
292-
if __stack_yield_fin then
293-
continuation(ExceptionCache.GetResultOrThrow awaiter).Invoke(&sm)
294-
else
295-
let mutable __stack_awaiter = awaiter
296-
sm.Data.MethodBuilder.AwaitUnsafeOnCompleted(&__stack_awaiter, &sm)
297-
false
298-
else
299-
Async2Builder.BindDynamic(&sm, awaiter, continuation))
300-
301-
member inline this.BindCancellable
302-
([<InlineIfLambda>] binding: CancellableAwaiter<'U, 'Awaiter>, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>)
303-
: Async2Code<'Data, 'T> =
304-
Async2Code(fun sm -> this.BindAwaiter(binding sm.Data.CancellationToken, continuation).Invoke(&sm))
305-
306-
member inline this.Bind(code: Async2<'U>, [<InlineIfLambda>] continuation: 'U -> Async2Code<'Data, 'T>) : Async2Code<'Data, 'T> =
307-
Async2Code(fun sm -> this.BindCancellable((fun ct -> code.StartImmediate(ct).Task.GetAwaiter()), continuation).Invoke(&sm))
308-
309-
member inline this.Bind(awaiter, [<InlineIfLambda>] continuation) = this.BindAwaiter(awaiter, continuation)
310-
311-
member inline this.Bind(cancellable, [<InlineIfLambda>] continuation) =
312-
this.BindCancellable(cancellable, continuation)
313-
314-
member inline this.ReturnFrom(code: Async2<'T>) : Async2Code<'T, 'T> = this.Bind(code, this.Return)
315-
316-
member inline this.ReturnFrom(awaiter) = this.BindAwaiter(awaiter, this.Return)
317-
318-
member inline this.ReturnFrom(cancellable) =
319-
this.BindCancellable(cancellable, this.Return)
320-
321-
member inline this.ReturnFromFinal(code: Async2<'T>) =
322-
Async2Code(fun sm ->
323-
let __stack_ct = sm.Data.CancellationToken
324-
325-
match sm.Data.TailCallSource with
326-
| ValueNone ->
327-
// This is the start of a tail call chain. we need to return here when the entire chain is done.
328-
let __stack_tcs = TaskCompletionSource<_>()
329-
code.TailCall(__stack_ct, ValueSome __stack_tcs)
330-
this.BindAwaiter(__stack_tcs.Task.GetAwaiter(), this.Return).Invoke(&sm)
331-
| ValueSome tcs ->
332-
// We are already in a tail call chain.
333-
Trampoline.Current.Set(fun () -> code.TailCall(__stack_ct, ValueSome tcs))
334-
false // Return false to abandon this state machine and continue on the next one.
335-
)
336-
337-
member inline this.ReturnFromFinal(awaiter) : Async2Code<'T, 'T> = this.BindAwaiter(awaiter, this.Return)
338-
339-
member inline this.ReturnFromFinal(cancellable) : Async2Code<'T, 'T> =
340-
this.BindCancellable(cancellable, this.Return)
341-
342300
static member inline RunDynamic(code: Async2Code<'T, 'T>) : Async2<'T> =
343301
let initialResumptionFunc = Async2ResumptionFunc<'T>(fun sm -> code.Invoke &sm)
344302

@@ -394,7 +352,7 @@ module Async2Implementation =
394352
sm.Data.MethodBuilder.SetStateMachine(state)
395353
}
396354

397-
Async2ImplDynamic<_, _>(fun () -> Async2StateMachine(ResumptionDynamicInfo = resumptionInfo ()))
355+
Async2Dynamic<_, _>(fun () -> Async2StateMachine(ResumptionDynamicInfo = resumptionInfo ()))
398356

399357
member inline _.Run(code: Async2Code<'T, 'T>) : Async2<'T> =
400358
if __useResumableCode then
@@ -430,39 +388,72 @@ module Async2Implementation =
430388

431389
(SetStateMachineMethodImpl<_>(fun sm state -> sm.Data.MethodBuilder.SetStateMachine state))
432390

433-
(AfterCode<_, _>(fun sm -> Async2Impl<_, _>(StateMachine = sm) :> Async2<'T>))
391+
(AfterCode<_, _>(fun sm -> Async2<_, _>(StateMachine = sm) :> Async2<'T>))
434392
else
435393
Async2Builder.RunDynamic(code)
436394

437-
member inline _.Source(code: Async2<_>) = code
395+
open Async2Implementation
438396

439397
[<AutoOpen>]
440-
module Async2AutoOpens =
441-
open Async2Implementation
398+
module LowPriority =
399+
type Async2Builder with
400+
[<NoEagerConstraintApplication>]
401+
member inline this.Bind(awaitable, [<InlineIfLambda>] continuation) =
402+
bindAwaiter (Awaitable.getAwaiter awaitable, continuation)
442403

443-
let async2 = Async2Builder()
404+
[<NoEagerConstraintApplication>]
405+
member inline this.ReturnFrom(awaitable) = this.Bind(awaitable, this.Return)
444406

445-
[<AutoOpen>]
446-
module Async2LowPriority =
447-
open Async2Implementation
407+
[<NoEagerConstraintApplication>]
408+
member inline this.ReturnFromFinal(awaitable) = this.ReturnFrom(awaitable)
448409

410+
[<AutoOpen>]
411+
module MediumPriority =
449412
type Async2Builder with
450-
member inline _.Source(awaitable: Awaitable<_, _, _>) = awaitable.GetAwaiter()
413+
member inline this.Bind(expr: Async<_>, [<InlineIfLambda>] continuation) =
414+
bindCancellable ((fun ct -> Async.StartAsTask(expr, cancellationToken = ct).GetAwaiter()), continuation)
415+
416+
member inline this.Bind(task: Task, [<InlineIfLambda>] continuation) =
417+
bindAwaiter (task.ConfigureAwait(false).GetAwaiter(), continuation)
451418

452-
member inline _.Source(items: _ seq) : _ seq = upcast items
419+
member inline this.Bind(task: Task<_>, [<InlineIfLambda>] continuation) =
420+
bindAwaiter (task.ConfigureAwait(false).GetAwaiter(), continuation)
421+
422+
member inline this.ReturnFrom(task: Task) = this.Bind(task, this.Return)
423+
member inline this.ReturnFrom(task: Task<_>) = this.Bind(task, this.Return)
424+
member inline this.ReturnFrom(expr: Async<_>) = this.Bind(expr, this.Return)
425+
member inline this.ReturnFromFinal(task: Task) = this.ReturnFrom(task)
426+
member inline this.ReturnFromFinal(task: Task<_>) = this.ReturnFrom(task)
427+
member inline this.ReturnFromFinal(expr: Async<_>) = this.ReturnFrom(expr)
453428

454429
[<AutoOpen>]
455-
module Async2MediumPriority =
456-
open Async2Implementation
430+
module HighPriority =
457431

458432
type Async2Builder with
459-
member inline _.Source(task: Task) = task.ConfigureAwait(false).GetAwaiter()
460-
member inline _.Source(task: Task<_>) = task.ConfigureAwait(false).GetAwaiter()
433+
member inline this.Bind(code: Async2<'U>, [<InlineIfLambda>] continuation) : Async2Code<'Data, 'T> =
434+
bindCancellable ((fun ct -> code.StartImmediate(ct).GetAwaiter()), continuation)
461435

462-
member inline this.Source(expr: Async<'T>) : CancellableAwaiter<_, _> =
463-
fun ct -> Async.StartAsTask(expr, cancellationToken = ct).GetAwaiter()
436+
member inline this.ReturnFrom(code: Async2<'T>) : Async2Code<'T, 'T> = this.Bind(code, this.Return)
464437

465-
open Async2Implementation
438+
member inline this.ReturnFromFinal(code: Async2<'T>) =
439+
Async2Code(fun sm ->
440+
match sm.Data.TailCallSource with
441+
| ValueNone ->
442+
// This is the start of a tail call chain. we need to return here when the entire chain is done.
443+
let __stack_tcs = TaskCompletionSource<_>()
444+
code.TailCall(sm.Data.CancellationToken, ValueSome __stack_tcs)
445+
this.Bind(__stack_tcs.Task, this.Return).Invoke(&sm)
446+
| ValueSome tcs ->
447+
// We are already in a tail call chain.
448+
let __stack_ct = sm.Data.CancellationToken
449+
Trampoline.Current.Set(fun () -> code.TailCall(__stack_ct, ValueSome tcs))
450+
false // Return false to abandon this state machine and continue on the next one.
451+
)
452+
453+
[<AutoOpen>]
454+
module Async2AutoOpens =
455+
456+
let async2 = Async2Builder()
466457

467458
module Async2 =
468459

0 commit comments

Comments
 (0)