Skip to content

Commit 5173bbd

Browse files
committed
Implement ^TaskLike version of Bind to allow more flexible types with do!, like ValueTask and non-generic Task
1 parent 639e2a6 commit 5173bbd

File tree

2 files changed

+131
-72
lines changed

2 files changed

+131
-72
lines changed

src/FSharp.Control.TaskSeq.Test/TaskSeq.Do.Tests.fs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,36 @@ open FSharp.Control
1111
let ``CE taskSeq: use 'do'`` () =
1212
let mutable value = 0
1313

14-
taskSeq { do value <- value + 1 }
15-
16-
|> verifyEmpty
14+
taskSeq { do value <- value + 1 } |> verifyEmpty
1715

1816
[<Fact>]
1917
let ``CE taskSeq: use 'do!' with a task<unit>`` () =
2018
let mutable value = 0
2119

2220
taskSeq { do! task { do value <- value + 1 } }
23-
2421
|> verifyEmpty
22+
|> Task.map (fun _ -> value |> should equal 1)
2523

26-
//[<Fact>]
27-
//let ``CE taskSeq: use 'do!' with a valuetask<unit>`` () =
28-
// let mutable value = 0
29-
30-
// taskSeq { do! ValueTask.ofIValueTaskSource (task { do value <- value + 1 }) }
31-
32-
// |> verifyEmpty
33-
34-
//[<Fact>]
35-
//let ``CE taskSeq: use 'do!' with a non-generic valuetask`` () =
36-
// let mutable value = 0
24+
[<Fact>]
25+
let ``CE taskSeq: use 'do!' with a valuetask<unit>`` () =
26+
let mutable value = 0
3727

38-
// taskSeq { do! ValueTask(task { do value <- value + 1 }) }
28+
taskSeq { do! ValueTask.ofTask (task { do value <- value + 1 }) }
29+
|> verifyEmpty
30+
|> Task.map (fun _ -> value |> should equal 1)
3931

40-
// |> verifyEmpty
32+
[<Fact>]
33+
let ``CE taskSeq: use 'do!' with a non-generic valuetask`` () =
34+
let mutable value = 0
4135

42-
//[<Fact>]
43-
//let ``CE taskSeq: use 'do!' with a non-generic task`` () =
44-
// let mutable value = 0
36+
taskSeq { do! ValueTask(task { do value <- value + 1 }) }
37+
|> verifyEmpty
38+
|> Task.map (fun _ -> value |> should equal 1)
4539

46-
// taskSeq { do! (task { do value <- value + 1 }) |> Task.ignore }
40+
[<Fact>]
41+
let ``CE taskSeq: use 'do!' with a non-generic task`` () =
42+
let mutable value = 0
4743

48-
// |> verifyEmpty
44+
taskSeq { do! (task { do value <- value + 1 }) |> Task.ignore }
45+
|> verifyEmpty
46+
|> Task.map (fun _ -> value |> should equal 1)

src/FSharp.Control.TaskSeq/TaskSeqBuilder.fs

Lines changed: 111 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -505,56 +505,6 @@ type TaskSeqBuilder() =
505505
sm.Data.awaiter <- null
506506
__stack_fin)
507507

508-
member inline _.Bind(task: Task<'TResult1>, continuation: ('TResult1 -> TaskSeqCode<'T>)) : TaskSeqCode<'T> =
509-
TaskSeqCode<'T>(fun sm ->
510-
let mutable awaiter = task.GetAwaiter()
511-
let mutable __stack_fin = true
512-
513-
Debug.logInfo "at Bind"
514-
515-
if not awaiter.IsCompleted then
516-
// This will yield with __stack_fin2 = false
517-
// This will resume with __stack_fin2 = true
518-
let __stack_fin2 = ResumableCode.Yield().Invoke(&sm)
519-
__stack_fin <- __stack_fin2
520-
521-
Debug.logInfo ("at Bind: with __stack_fin = ", __stack_fin)
522-
Debug.logInfo ("at Bind: this.completed = ", sm.Data.completed)
523-
524-
if __stack_fin then
525-
let result = awaiter.GetResult()
526-
(continuation result).Invoke(&sm)
527-
528-
else
529-
Debug.logInfo "at Bind: calling AwaitUnsafeOnCompleted"
530-
531-
sm.Data.awaiter <- awaiter
532-
sm.Data.current <- ValueNone
533-
false)
534-
535-
member inline _.Bind(task: ValueTask<'TResult1>, continuation: ('TResult1 -> TaskSeqCode<'T>)) : TaskSeqCode<'T> =
536-
TaskSeqCode<'T>(fun sm ->
537-
let mutable awaiter = task.GetAwaiter()
538-
let mutable __stack_fin = true
539-
540-
Debug.logInfo "at BindV"
541-
542-
if not awaiter.IsCompleted then
543-
// This will yield with __stack_fin2 = false
544-
// This will resume with __stack_fin2 = true
545-
let __stack_fin2 = ResumableCode.Yield().Invoke(&sm)
546-
__stack_fin <- __stack_fin2
547-
548-
if __stack_fin then
549-
let result = awaiter.GetResult()
550-
(continuation result).Invoke(&sm)
551-
else
552-
Debug.logInfo "at BindV: calling AwaitUnsafeOnCompleted"
553-
554-
sm.Data.awaiter <- awaiter
555-
sm.Data.current <- ValueNone
556-
false)
557-
558508
//
559509
// These "modules of priority" allow for an indecisive F# to resolve
560510
// the proper overload if a single type implements more than one
@@ -567,6 +517,58 @@ type TaskSeqBuilder() =
567517
// (like For depending on Using etc).
568518
//
569519

520+
[<AutoOpen>]
521+
module LowPriority =
522+
type TaskSeqBuilder with
523+
524+
//
525+
// Note: we cannot place _.Bind directly on the type, as the NoEagerXXX attribute
526+
// has no effect, and each use of `do!` will give an overload error (because the
527+
// `TaskLike` type and the `Task<_>` type are interchangeable).
528+
//
529+
// However, we cannot unify these two methods, because Task<_> inherits from Task (non-generic)
530+
// and we need a way to distinguish these two methods.
531+
//
532+
533+
[<NoEagerConstraintApplication>]
534+
member inline _.Bind< ^TaskLike, 'TResult1, 'TResult2, ^Awaiter, 'TOverall
535+
when ^TaskLike: (member GetAwaiter: unit -> ^Awaiter)
536+
and ^Awaiter :> ICriticalNotifyCompletion
537+
and ^Awaiter: (member get_IsCompleted: unit -> bool)
538+
and ^Awaiter: (member GetResult: unit -> 'TResult1)>
539+
(
540+
task: ^TaskLike,
541+
continuation: ('TResult1 -> TaskSeqCode<'TResult2>)
542+
) : TaskSeqCode<'TResult2> =
543+
544+
TaskSeqCode<'TResult2>(fun sm ->
545+
let mutable awaiter = (^TaskLike: (member GetAwaiter: unit -> ^Awaiter) (task))
546+
let mutable __stack_fin = true
547+
548+
Debug.logInfo "at TaskLike bind!"
549+
550+
if not (^Awaiter: (member get_IsCompleted: unit -> bool) (awaiter)) then
551+
// This will yield with __stack_fin2 = false
552+
// This will resume with __stack_fin2 = true
553+
let __stack_fin2 = ResumableCode.Yield().Invoke(&sm)
554+
__stack_fin <- __stack_fin2
555+
556+
Debug.logInfo ("at TaskLike bind!: with __stack_fin = ", __stack_fin)
557+
Debug.logInfo ("at TaskLike bind!: this.completed = ", sm.Data.completed)
558+
559+
if __stack_fin then
560+
Debug.logInfo "at TaskLike bind!: finished awaiting, calling continuation"
561+
let result = (^Awaiter: (member GetResult: unit -> 'TResult1) (awaiter))
562+
(continuation result).Invoke(&sm)
563+
564+
else
565+
Debug.logInfo "at TaskLike bind!: await further"
566+
567+
sm.Data.awaiter <- awaiter
568+
sm.Data.current <- ValueNone
569+
false)
570+
571+
570572
[<AutoOpen>]
571573
module MediumPriority =
572574
type TaskSeqBuilder with
@@ -608,3 +610,62 @@ module MediumPriority =
608610

609611
member inline this.YieldFrom(source: IAsyncEnumerable<'T>) : TaskSeqCode<'T> =
610612
this.For(source, (fun v -> this.Yield(v)))
613+
614+
[<AutoOpen>]
615+
module HighPriority =
616+
type TaskSeqBuilder with
617+
618+
member inline _.Bind(task: Task<'TResult1>, continuation: ('TResult1 -> TaskSeqCode<'T>)) : TaskSeqCode<'T> =
619+
TaskSeqCode<'T>(fun sm ->
620+
let mutable awaiter = task.GetAwaiter()
621+
let mutable __stack_fin = true
622+
623+
Debug.logInfo "at Bind"
624+
625+
if not awaiter.IsCompleted then
626+
// This will yield with __stack_fin2 = false
627+
// This will resume with __stack_fin2 = true
628+
let __stack_fin2 = ResumableCode.Yield().Invoke(&sm)
629+
__stack_fin <- __stack_fin2
630+
631+
Debug.logInfo ("at Bind: with __stack_fin = ", __stack_fin)
632+
Debug.logInfo ("at Bind: this.completed = ", sm.Data.completed)
633+
634+
if __stack_fin then
635+
Debug.logInfo "at Bind: finished awaiting, calling continuation"
636+
let result = awaiter.GetResult()
637+
(continuation result).Invoke(&sm)
638+
639+
else
640+
Debug.logInfo "at Bind: await further"
641+
642+
sm.Data.awaiter <- awaiter
643+
sm.Data.current <- ValueNone
644+
false)
645+
646+
member inline _.Bind
647+
(
648+
task: ValueTask<'TResult1>,
649+
continuation: ('TResult1 -> TaskSeqCode<'T>)
650+
) : TaskSeqCode<'T> =
651+
TaskSeqCode<'T>(fun sm ->
652+
let mutable awaiter = task.GetAwaiter()
653+
let mutable __stack_fin = true
654+
655+
Debug.logInfo "at BindV"
656+
657+
if not awaiter.IsCompleted then
658+
// This will yield with __stack_fin2 = false
659+
// This will resume with __stack_fin2 = true
660+
let __stack_fin2 = ResumableCode.Yield().Invoke(&sm)
661+
__stack_fin <- __stack_fin2
662+
663+
if __stack_fin then
664+
let result = awaiter.GetResult()
665+
(continuation result).Invoke(&sm)
666+
else
667+
Debug.logInfo "at BindV: calling AwaitUnsafeOnCompleted"
668+
669+
sm.Data.awaiter <- awaiter
670+
sm.Data.current <- ValueNone
671+
false)

0 commit comments

Comments
 (0)