Skip to content

Commit 33c4c10

Browse files
committed
Add null-checks for each function that takes an IAsyncEnumerable or otherwise nullable type
1 parent 3e95bc5 commit 33c4c10

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

src/FSharp.Control.TaskSeq/TaskSeq.fs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ module TaskSeq =
3030
//
3131

3232
let toList (source: taskSeq<'T>) = [
33+
Internal.checkNonNull (nameof source) source
3334
let e = source.GetAsyncEnumerator(CancellationToken())
3435

3536
try
@@ -40,6 +41,7 @@ module TaskSeq =
4041
]
4142

4243
let toArray (source: taskSeq<'T>) = [|
44+
Internal.checkNonNull (nameof source) source
4345
let e = source.GetAsyncEnumerator(CancellationToken())
4446

4547
try
@@ -50,6 +52,7 @@ module TaskSeq =
5052
|]
5153

5254
let toSeq (source: taskSeq<'T>) = seq {
55+
Internal.checkNonNull (nameof source) source
5356
let e = source.GetAsyncEnumerator(CancellationToken())
5457

5558
try
@@ -74,6 +77,8 @@ module TaskSeq =
7477
//
7578

7679
let ofArray (source: 'T[]) = taskSeq {
80+
Internal.checkNonNull (nameof source) source
81+
7782
for c in source do
7883
yield c
7984
}
@@ -84,16 +89,22 @@ module TaskSeq =
8489
}
8590

8691
let ofSeq (source: 'T seq) = taskSeq {
92+
Internal.checkNonNull (nameof source) source
93+
8794
for c in source do
8895
yield c
8996
}
9097

9198
let ofResizeArray (source: 'T ResizeArray) = taskSeq {
99+
Internal.checkNonNull (nameof source) source
100+
92101
for c in source do
93102
yield c
94103
}
95104

96105
let ofTaskSeq (source: #Task<'T> seq) = taskSeq {
106+
Internal.checkNonNull (nameof source) source
107+
97108
for c in source do
98109
let! c = c
99110
yield c
@@ -106,12 +117,16 @@ module TaskSeq =
106117
}
107118

108119
let ofTaskArray (source: #Task<'T> array) = taskSeq {
120+
Internal.checkNonNull (nameof source) source
121+
109122
for c in source do
110123
let! c = c
111124
yield c
112125
}
113126

114127
let ofAsyncSeq (source: Async<'T> seq) = taskSeq {
128+
Internal.checkNonNull (nameof source) source
129+
115130
for c in source do
116131
let! c = task { return! c }
117132
yield c
@@ -124,6 +139,8 @@ module TaskSeq =
124139
}
125140

126141
let ofAsyncArray (source: Async<'T> array) = taskSeq {
142+
Internal.checkNonNull (nameof source) source
143+
127144
for c in source do
128145
let! c = Async.toTask c
129146
yield c
@@ -148,21 +165,29 @@ module TaskSeq =
148165
}
149166

150167
let concat (sources: taskSeq<#taskSeq<'T>>) = taskSeq {
168+
Internal.checkNonNull (nameof sources) sources
169+
151170
for ts in sources do
152171
yield! (ts :> taskSeq<'T>)
153172
}
154173

155174
let append (source1: #taskSeq<'T>) (source2: #taskSeq<'T>) = taskSeq {
175+
Internal.checkNonNull (nameof source1) source1
176+
Internal.checkNonNull (nameof source2) source2
156177
yield! (source1 :> IAsyncEnumerable<'T>)
157178
yield! (source2 :> IAsyncEnumerable<'T>)
158179
}
159180

160181
let appendSeq (source1: #taskSeq<'T>) (source2: #seq<'T>) = taskSeq {
182+
Internal.checkNonNull (nameof source1) source1
183+
Internal.checkNonNull (nameof source2) source2
161184
yield! (source1 :> IAsyncEnumerable<'T>)
162185
yield! (source2 :> seq<'T>)
163186
}
164187

165188
let prependSeq (source1: #seq<'T>) (source2: #taskSeq<'T>) = taskSeq {
189+
Internal.checkNonNull (nameof source1) source1
190+
Internal.checkNonNull (nameof source2) source2
166191
yield! (source1 :> seq<'T>)
167192
yield! (source2 :> IAsyncEnumerable<'T>)
168193
}
@@ -242,6 +267,7 @@ module TaskSeq =
242267
}
243268

244269
let indexed (source: taskSeq<'T>) = taskSeq {
270+
Internal.checkNonNull (nameof source) source
245271
let mutable i = 0
246272

247273
for x in source do

src/FSharp.Control.TaskSeq/TaskSeqInternal.fs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ type internal InitAction<'T, 'TaskT when 'TaskT :> Task<'T>> =
4444
| InitActionAsync of async_init_item: (int -> 'TaskT)
4545

4646
module internal TaskSeqInternal =
47+
/// Raise an NRE for arguments that are null. Only used for 'source' parameters, never for function parameters.
48+
let inline checkNonNull argName arg =
49+
if isNull arg then
50+
nullArg argName
51+
4752
let inline raiseEmptySeq () =
4853
ArgumentException("The asynchronous input sequence was empty.", "source")
4954
|> raise
@@ -61,6 +66,7 @@ module internal TaskSeqInternal =
6166
|> raise
6267

6368
let isEmpty (source: taskSeq<_>) = task {
69+
checkNonNull (nameof source) source
6470
use e = source.GetAsyncEnumerator(CancellationToken())
6571
let! step = e.MoveNextAsync()
6672
return not step
@@ -93,6 +99,7 @@ module internal TaskSeqInternal =
9399

94100
/// Returns length unconditionally, or based on a predicate
95101
let lengthBy predicate (source: taskSeq<_>) = task {
102+
checkNonNull (nameof source) source
96103
use e = source.GetAsyncEnumerator(CancellationToken())
97104
let mutable go = true
98105
let mutable i = 0
@@ -128,6 +135,7 @@ module internal TaskSeqInternal =
128135

129136
/// Returns length unconditionally, or based on a predicate
130137
let lengthBeforeMax max (source: taskSeq<_>) = task {
138+
checkNonNull (nameof source) source
131139
use e = source.GetAsyncEnumerator(CancellationToken())
132140
let mutable go = true
133141
let mutable i = 0
@@ -143,6 +151,7 @@ module internal TaskSeqInternal =
143151
}
144152

145153
let tryExactlyOne (source: taskSeq<_>) = task {
154+
checkNonNull (nameof source) source
146155
use e = source.GetAsyncEnumerator(CancellationToken())
147156

148157
match! e.MoveNextAsync() with
@@ -199,6 +208,7 @@ module internal TaskSeqInternal =
199208
}
200209

201210
let iter action (source: taskSeq<_>) = task {
211+
checkNonNull (nameof source) source
202212
use e = source.GetAsyncEnumerator(CancellationToken())
203213
let mutable go = true
204214
let! step = e.MoveNextAsync()
@@ -240,6 +250,7 @@ module internal TaskSeqInternal =
240250
}
241251

242252
let fold folder initial (source: taskSeq<_>) = task {
253+
checkNonNull (nameof source) source
243254
use e = source.GetAsyncEnumerator(CancellationToken())
244255
let mutable go = true
245256
let mutable result = initial
@@ -264,44 +275,49 @@ module internal TaskSeqInternal =
264275
}
265276

266277
let toResizeArrayAsync source = task {
278+
checkNonNull (nameof source) source
267279
let res = ResizeArray()
268280
do! source |> iter (SimpleAction(fun item -> res.Add item))
269281
return res
270282
}
271283

272284
let toResizeArrayAndMapAsync mapper source = (toResizeArrayAsync >> Task.map mapper) source
273285

274-
let map mapper (taskSequence: taskSeq<_>) =
286+
let map mapper (source: taskSeq<_>) =
287+
checkNonNull (nameof source) source
288+
275289
match mapper with
276290
| CountableAction mapper -> taskSeq {
277291
let mutable i = 0
278292

279-
for c in taskSequence do
293+
for c in source do
280294
yield mapper i c
281295
i <- i + 1
282296
}
283297

284298
| SimpleAction mapper -> taskSeq {
285-
for c in taskSequence do
299+
for c in source do
286300
yield mapper c
287301
}
288302

289303
| AsyncCountableAction mapper -> taskSeq {
290304
let mutable i = 0
291305

292-
for c in taskSequence do
306+
for c in source do
293307
let! result = mapper i c
294308
yield result
295309
i <- i + 1
296310
}
297311

298312
| AsyncSimpleAction mapper -> taskSeq {
299-
for c in taskSequence do
313+
for c in source do
300314
let! result = mapper c
301315
yield result
302316
}
303317

304318
let zip (source1: taskSeq<_>) (source2: taskSeq<_>) = taskSeq {
319+
checkNonNull (nameof source1) source1
320+
checkNonNull (nameof source2) source2
305321
use e1 = source1.GetAsyncEnumerator(CancellationToken())
306322
use e2 = source2.GetAsyncEnumerator(CancellationToken())
307323
let mutable go = true
@@ -317,28 +333,37 @@ module internal TaskSeqInternal =
317333
}
318334

319335
let collect (binder: _ -> #IAsyncEnumerable<_>) (source: taskSeq<_>) = taskSeq {
336+
checkNonNull (nameof source) source
337+
320338
for c in source do
321339
yield! binder c :> IAsyncEnumerable<_>
322340
}
323341

324342
let collectSeq (binder: _ -> #seq<_>) (source: taskSeq<_>) = taskSeq {
343+
checkNonNull (nameof source) source
344+
325345
for c in source do
326346
yield! binder c :> seq<_>
327347
}
328348

329349
let collectAsync (binder: _ -> #Task<#IAsyncEnumerable<_>>) (source: taskSeq<_>) = taskSeq {
350+
checkNonNull (nameof source) source
351+
330352
for c in source do
331353
let! result = binder c
332354
yield! result :> IAsyncEnumerable<_>
333355
}
334356

335357
let collectSeqAsync (binder: _ -> #Task<#seq<_>>) (source: taskSeq<_>) = taskSeq {
358+
checkNonNull (nameof source) source
359+
336360
for c in source do
337361
let! result = binder c
338362
yield! result :> seq<_>
339363
}
340364

341365
let tryLast (source: taskSeq<_>) = task {
366+
checkNonNull (nameof source) source
342367
use e = source.GetAsyncEnumerator(CancellationToken())
343368
let mutable go = true
344369
let mutable last = ValueNone
@@ -356,6 +381,7 @@ module internal TaskSeqInternal =
356381
}
357382

358383
let tryHead (source: taskSeq<_>) = task {
384+
checkNonNull (nameof source) source
359385
use e = source.GetAsyncEnumerator(CancellationToken())
360386

361387
match! e.MoveNextAsync() with
@@ -364,6 +390,7 @@ module internal TaskSeqInternal =
364390
}
365391

366392
let tryTail (source: taskSeq<_>) = task {
393+
checkNonNull (nameof source) source
367394
use e = source.GetAsyncEnumerator(CancellationToken())
368395

369396
match! e.MoveNextAsync() with
@@ -384,6 +411,8 @@ module internal TaskSeqInternal =
384411
}
385412

386413
let tryItem index (source: taskSeq<_>) = task {
414+
checkNonNull (nameof source) source
415+
387416
if index < 0 then
388417
// while the loop below wouldn't run anyway, we don't want to call MoveNext in this case
389418
// to prevent side effects hitting unnecessarily
@@ -409,6 +438,7 @@ module internal TaskSeqInternal =
409438
}
410439

411440
let tryPick chooser (source: taskSeq<_>) = task {
441+
checkNonNull (nameof source) source
412442
use e = source.GetAsyncEnumerator(CancellationToken())
413443

414444
let mutable go = true
@@ -441,6 +471,7 @@ module internal TaskSeqInternal =
441471
}
442472

443473
let tryFind predicate (source: taskSeq<_>) = task {
474+
checkNonNull (nameof source) source
444475
use e = source.GetAsyncEnumerator(CancellationToken())
445476

446477
let mutable go = true
@@ -477,6 +508,7 @@ module internal TaskSeqInternal =
477508
}
478509

479510
let tryFindIndex predicate (source: taskSeq<_>) = task {
511+
checkNonNull (nameof source) source
480512
use e = source.GetAsyncEnumerator(CancellationToken())
481513

482514
let mutable go = true
@@ -509,6 +541,8 @@ module internal TaskSeqInternal =
509541
}
510542

511543
let choose chooser (source: taskSeq<_>) = taskSeq {
544+
checkNonNull (nameof source) source
545+
512546
match chooser with
513547
| TryPick picker ->
514548
for item in source do
@@ -524,6 +558,8 @@ module internal TaskSeqInternal =
524558
}
525559

526560
let filter predicate (source: taskSeq<_>) = taskSeq {
561+
checkNonNull (nameof source) source
562+
527563
match predicate with
528564
| Predicate predicate ->
529565
for item in source do
@@ -642,6 +678,7 @@ module internal TaskSeqInternal =
642678
ValueTask.CompletedTask
643679

644680
let except itemsToExclude (source: taskSeq<_>) = taskSeq {
681+
checkNonNull (nameof source) source
645682
use e = source.GetAsyncEnumerator(CancellationToken())
646683
let mutable go = true
647684
let! step = e.MoveNextAsync()
@@ -666,6 +703,7 @@ module internal TaskSeqInternal =
666703
}
667704

668705
let exceptOfSeq itemsToExclude (source: taskSeq<_>) = taskSeq {
706+
checkNonNull (nameof source) source
669707
use e = source.GetAsyncEnumerator(CancellationToken())
670708
let mutable go = true
671709
let! step = e.MoveNextAsync()

0 commit comments

Comments
 (0)