Skip to content

Commit b550f3c

Browse files
committed
Add null-checks for each function that takes an IAsyncEnumerable or otherwise nullable type
1 parent a965a76 commit b550f3c

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

src/FSharp.Control.TaskSeq/TaskSeq.fs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ module TaskSeq =
3030
//
3131

3232
let toList (source: taskSeq<'T>) = [
33+
Internal.checkNonNull (nameof source) source
3334
let e = source.GetAsyncEnumerator(CancellationToken())
3435

3536
try
@@ -40,6 +41,7 @@ module TaskSeq =
4041
]
4142

4243
let toArray (source: taskSeq<'T>) = [|
44+
Internal.checkNonNull (nameof source) source
4345
let e = source.GetAsyncEnumerator(CancellationToken())
4446

4547
try
@@ -50,6 +52,7 @@ module TaskSeq =
5052
|]
5153

5254
let toSeq (source: taskSeq<'T>) = seq {
55+
Internal.checkNonNull (nameof source) source
5356
let e = source.GetAsyncEnumerator(CancellationToken())
5457

5558
try
@@ -74,6 +77,8 @@ module TaskSeq =
7477
//
7578

7679
let ofArray (source: 'T[]) = taskSeq {
80+
Internal.checkNonNull (nameof source) source
81+
7782
for c in source do
7883
yield c
7984
}
@@ -84,16 +89,22 @@ module TaskSeq =
8489
}
8590

8691
let ofSeq (source: 'T seq) = taskSeq {
92+
Internal.checkNonNull (nameof source) source
93+
8794
for c in source do
8895
yield c
8996
}
9097

9198
let ofResizeArray (source: 'T ResizeArray) = taskSeq {
99+
Internal.checkNonNull (nameof source) source
100+
92101
for c in source do
93102
yield c
94103
}
95104

96105
let ofTaskSeq (source: #Task<'T> seq) = taskSeq {
106+
Internal.checkNonNull (nameof source) source
107+
97108
for c in source do
98109
let! c = c
99110
yield c
@@ -106,12 +117,16 @@ module TaskSeq =
106117
}
107118

108119
let ofTaskArray (source: #Task<'T> array) = taskSeq {
120+
Internal.checkNonNull (nameof source) source
121+
109122
for c in source do
110123
let! c = c
111124
yield c
112125
}
113126

114127
let ofAsyncSeq (source: Async<'T> seq) = taskSeq {
128+
Internal.checkNonNull (nameof source) source
129+
115130
for c in source do
116131
let! c = task { return! c }
117132
yield c
@@ -124,6 +139,8 @@ module TaskSeq =
124139
}
125140

126141
let ofAsyncArray (source: Async<'T> array) = taskSeq {
142+
Internal.checkNonNull (nameof source) source
143+
127144
for c in source do
128145
let! c = Async.toTask c
129146
yield c
@@ -148,21 +165,29 @@ module TaskSeq =
148165
}
149166

150167
let concat (sources: taskSeq<#taskSeq<'T>>) = taskSeq {
168+
Internal.checkNonNull (nameof sources) sources
169+
151170
for ts in sources do
152171
yield! (ts :> taskSeq<'T>)
153172
}
154173

155174
let append (source1: #taskSeq<'T>) (source2: #taskSeq<'T>) = taskSeq {
175+
Internal.checkNonNull (nameof source1) source1
176+
Internal.checkNonNull (nameof source2) source2
156177
yield! (source1 :> IAsyncEnumerable<'T>)
157178
yield! (source2 :> IAsyncEnumerable<'T>)
158179
}
159180

160181
let appendSeq (source1: #taskSeq<'T>) (source2: #seq<'T>) = taskSeq {
182+
Internal.checkNonNull (nameof source1) source1
183+
Internal.checkNonNull (nameof source2) source2
161184
yield! (source1 :> IAsyncEnumerable<'T>)
162185
yield! (source2 :> seq<'T>)
163186
}
164187

165188
let prependSeq (source1: #seq<'T>) (source2: #taskSeq<'T>) = taskSeq {
189+
Internal.checkNonNull (nameof source1) source1
190+
Internal.checkNonNull (nameof source2) source2
166191
yield! (source1 :> seq<'T>)
167192
yield! (source2 :> IAsyncEnumerable<'T>)
168193
}
@@ -242,6 +267,7 @@ module TaskSeq =
242267
}
243268

244269
let indexed (source: taskSeq<'T>) = taskSeq {
270+
Internal.checkNonNull (nameof source) source
245271
let mutable i = 0
246272

247273
for x in source do

src/FSharp.Control.TaskSeq/TaskSeqInternal.fs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ type internal InitAction<'T, 'TaskT when 'TaskT :> Task<'T>> =
3939
| InitActionAsync of async_init_item: (int -> 'TaskT)
4040

4141
module internal TaskSeqInternal =
42+
/// Raise an NRE for arguments that are null. Only used for 'source' parameters, never for function parameters.
43+
let inline checkNonNull argName arg =
44+
if isNull arg then
45+
nullArg argName
46+
4247
let inline raiseEmptySeq () =
4348
ArgumentException("The asynchronous input sequence was empty.", "source")
4449
|> raise
@@ -56,6 +61,7 @@ module internal TaskSeqInternal =
5661
|> raise
5762

5863
let isEmpty (source: taskSeq<_>) = task {
64+
checkNonNull (nameof source) source
5965
use e = source.GetAsyncEnumerator(CancellationToken())
6066
let! step = e.MoveNextAsync()
6167
return not step
@@ -88,6 +94,7 @@ module internal TaskSeqInternal =
8894

8995
/// Returns length unconditionally, or based on a predicate
9096
let lengthBy predicate (source: taskSeq<_>) = task {
97+
checkNonNull (nameof source) source
9198
use e = source.GetAsyncEnumerator(CancellationToken())
9299
let mutable go = true
93100
let mutable i = 0
@@ -123,6 +130,7 @@ module internal TaskSeqInternal =
123130

124131
/// Returns length unconditionally, or based on a predicate
125132
let lengthBeforeMax max (source: taskSeq<_>) = task {
133+
checkNonNull (nameof source) source
126134
use e = source.GetAsyncEnumerator(CancellationToken())
127135
let mutable go = true
128136
let mutable i = 0
@@ -138,6 +146,7 @@ module internal TaskSeqInternal =
138146
}
139147

140148
let tryExactlyOne (source: taskSeq<_>) = task {
149+
checkNonNull (nameof source) source
141150
use e = source.GetAsyncEnumerator(CancellationToken())
142151

143152
match! e.MoveNextAsync() with
@@ -194,6 +203,7 @@ module internal TaskSeqInternal =
194203
}
195204

196205
let iter action (source: taskSeq<_>) = task {
206+
checkNonNull (nameof source) source
197207
use e = source.GetAsyncEnumerator(CancellationToken())
198208
let mutable go = true
199209
let! step = e.MoveNextAsync()
@@ -235,6 +245,7 @@ module internal TaskSeqInternal =
235245
}
236246

237247
let fold folder initial (source: taskSeq<_>) = task {
248+
checkNonNull (nameof source) source
238249
use e = source.GetAsyncEnumerator(CancellationToken())
239250
let mutable go = true
240251
let mutable result = initial
@@ -259,44 +270,49 @@ module internal TaskSeqInternal =
259270
}
260271

261272
let toResizeArrayAsync source = task {
273+
checkNonNull (nameof source) source
262274
let res = ResizeArray()
263275
do! source |> iter (SimpleAction(fun item -> res.Add item))
264276
return res
265277
}
266278

267279
let toResizeArrayAndMapAsync mapper source = (toResizeArrayAsync >> Task.map mapper) source
268280

269-
let map mapper (taskSequence: taskSeq<_>) =
281+
let map mapper (source: taskSeq<_>) =
282+
checkNonNull (nameof source) source
283+
270284
match mapper with
271285
| CountableAction mapper -> taskSeq {
272286
let mutable i = 0
273287

274-
for c in taskSequence do
288+
for c in source do
275289
yield mapper i c
276290
i <- i + 1
277291
}
278292

279293
| SimpleAction mapper -> taskSeq {
280-
for c in taskSequence do
294+
for c in source do
281295
yield mapper c
282296
}
283297

284298
| AsyncCountableAction mapper -> taskSeq {
285299
let mutable i = 0
286300

287-
for c in taskSequence do
301+
for c in source do
288302
let! result = mapper i c
289303
yield result
290304
i <- i + 1
291305
}
292306

293307
| AsyncSimpleAction mapper -> taskSeq {
294-
for c in taskSequence do
308+
for c in source do
295309
let! result = mapper c
296310
yield result
297311
}
298312

299313
let zip (source1: taskSeq<_>) (source2: taskSeq<_>) = taskSeq {
314+
checkNonNull (nameof source1) source1
315+
checkNonNull (nameof source2) source2
300316
use e1 = source1.GetAsyncEnumerator(CancellationToken())
301317
use e2 = source2.GetAsyncEnumerator(CancellationToken())
302318
let mutable go = true
@@ -312,28 +328,37 @@ module internal TaskSeqInternal =
312328
}
313329

314330
let collect (binder: _ -> #IAsyncEnumerable<_>) (source: taskSeq<_>) = taskSeq {
331+
checkNonNull (nameof source) source
332+
315333
for c in source do
316334
yield! binder c :> IAsyncEnumerable<_>
317335
}
318336

319337
let collectSeq (binder: _ -> #seq<_>) (source: taskSeq<_>) = taskSeq {
338+
checkNonNull (nameof source) source
339+
320340
for c in source do
321341
yield! binder c :> seq<_>
322342
}
323343

324344
let collectAsync (binder: _ -> #Task<#IAsyncEnumerable<_>>) (source: taskSeq<_>) = taskSeq {
345+
checkNonNull (nameof source) source
346+
325347
for c in source do
326348
let! result = binder c
327349
yield! result :> IAsyncEnumerable<_>
328350
}
329351

330352
let collectSeqAsync (binder: _ -> #Task<#seq<_>>) (source: taskSeq<_>) = taskSeq {
353+
checkNonNull (nameof source) source
354+
331355
for c in source do
332356
let! result = binder c
333357
yield! result :> seq<_>
334358
}
335359

336360
let tryLast (source: taskSeq<_>) = task {
361+
checkNonNull (nameof source) source
337362
use e = source.GetAsyncEnumerator(CancellationToken())
338363
let mutable go = true
339364
let mutable last = ValueNone
@@ -351,6 +376,7 @@ module internal TaskSeqInternal =
351376
}
352377

353378
let tryHead (source: taskSeq<_>) = task {
379+
checkNonNull (nameof source) source
354380
use e = source.GetAsyncEnumerator(CancellationToken())
355381

356382
match! e.MoveNextAsync() with
@@ -359,6 +385,7 @@ module internal TaskSeqInternal =
359385
}
360386

361387
let tryTail (source: taskSeq<_>) = task {
388+
checkNonNull (nameof source) source
362389
use e = source.GetAsyncEnumerator(CancellationToken())
363390

364391
match! e.MoveNextAsync() with
@@ -379,6 +406,8 @@ module internal TaskSeqInternal =
379406
}
380407

381408
let tryItem index (source: taskSeq<_>) = task {
409+
checkNonNull (nameof source) source
410+
382411
if index < 0 then
383412
// while the loop below wouldn't run anyway, we don't want to call MoveNext in this case
384413
// to prevent side effects hitting unnecessarily
@@ -404,6 +433,7 @@ module internal TaskSeqInternal =
404433
}
405434

406435
let tryPick chooser (source: taskSeq<_>) = task {
436+
checkNonNull (nameof source) source
407437
use e = source.GetAsyncEnumerator(CancellationToken())
408438

409439
let mutable go = true
@@ -436,6 +466,7 @@ module internal TaskSeqInternal =
436466
}
437467

438468
let tryFind predicate (source: taskSeq<_>) = task {
469+
checkNonNull (nameof source) source
439470
use e = source.GetAsyncEnumerator(CancellationToken())
440471

441472
let mutable go = true
@@ -472,6 +503,7 @@ module internal TaskSeqInternal =
472503
}
473504

474505
let tryFindIndex predicate (source: taskSeq<_>) = task {
506+
checkNonNull (nameof source) source
475507
use e = source.GetAsyncEnumerator(CancellationToken())
476508

477509
let mutable go = true
@@ -504,6 +536,8 @@ module internal TaskSeqInternal =
504536
}
505537

506538
let choose chooser (source: taskSeq<_>) = taskSeq {
539+
checkNonNull (nameof source) source
540+
507541
match chooser with
508542
| TryPick picker ->
509543
for item in source do
@@ -519,6 +553,8 @@ module internal TaskSeqInternal =
519553
}
520554

521555
let filter predicate (source: taskSeq<_>) = taskSeq {
556+
checkNonNull (nameof source) source
557+
522558
match predicate with
523559
| Predicate predicate ->
524560
for item in source do
@@ -585,6 +621,7 @@ module internal TaskSeqInternal =
585621
ValueTask.CompletedTask
586622

587623
let except itemsToExclude (source: taskSeq<_>) = taskSeq {
624+
checkNonNull (nameof source) source
588625
use e = source.GetAsyncEnumerator(CancellationToken())
589626
let mutable go = true
590627
let! step = e.MoveNextAsync()
@@ -609,6 +646,7 @@ module internal TaskSeqInternal =
609646
}
610647

611648
let exceptOfSeq itemsToExclude (source: taskSeq<_>) = taskSeq {
649+
checkNonNull (nameof source) source
612650
use e = source.GetAsyncEnumerator(CancellationToken())
613651
let mutable go = true
614652
let! step = e.MoveNextAsync()

0 commit comments

Comments
 (0)