@@ -30,8 +30,9 @@ public ConcatMapEager(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnum
3030
3131 public IAsyncEnumerator < TResult > GetAsyncEnumerator ( CancellationToken cancellationToken )
3232 {
33- var en = new ConcatMapEagerEnumerator ( _source . GetAsyncEnumerator ( cancellationToken ) , _mapper , _maxConcurrency , _prefetch ,
34- cancellationToken ) ;
33+ var sourceCTS = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
34+ var en = new ConcatMapEagerEnumerator ( _source . GetAsyncEnumerator ( sourceCTS . Token ) , _mapper , _maxConcurrency , _prefetch ,
35+ sourceCTS ) ;
3536 en . MoveNextSource ( ) ;
3637 return en ;
3738 }
@@ -44,7 +45,7 @@ private sealed class ConcatMapEagerEnumerator : IAsyncEnumerator<TResult>
4445
4546 private readonly int _prefetch ;
4647
47- private readonly CancellationToken _ct ;
48+ private readonly CancellationTokenSource _sourceCTS ;
4849
4950 private int _sourceOutstanding ;
5051
@@ -69,7 +70,7 @@ private sealed class ConcatMapEagerEnumerator : IAsyncEnumerator<TResult>
6970
7071 public ConcatMapEagerEnumerator ( IAsyncEnumerator < TSource > source ,
7172 Func < TSource , IAsyncEnumerable < TResult > > mapper , int maxConcurrency , int prefetch ,
72- CancellationToken ct )
73+ CancellationTokenSource cts )
7374 {
7475 _source = source ;
7576 _mapper = mapper ;
@@ -78,11 +79,12 @@ public ConcatMapEagerEnumerator(IAsyncEnumerator<TSource> source,
7879 _disposeWip = 1 ;
7980 _inners = new ConcurrentQueue < InnerHandler > ( ) ;
8081 _disposeTask = new TaskCompletionSource < bool > ( ) ;
81- _ct = ct ;
82+ _sourceCTS = cts ;
8283 }
8384
8485 public ValueTask DisposeAsync ( )
8586 {
87+ _sourceCTS . Cancel ( ) ;
8688 _disposeRequested = true ;
8789 if ( Interlocked . Increment ( ref _sourceDisposeWip ) == 1 )
8890 {
@@ -175,6 +177,15 @@ private bool TryDispose()
175177
176178 private void NextHandler ( Task < bool > t )
177179 {
180+ if ( t . IsCanceled )
181+ {
182+ ExceptionHelper . AddException ( ref _error , new OperationCanceledException ( ) ) ;
183+ _sourceDone = true ;
184+ if ( TryDispose ( ) )
185+ {
186+ ResumeHelper . Resume ( ref _resume ) ;
187+ }
188+ } else
178189 if ( t . IsFaulted )
179190 {
180191 ExceptionHelper . AddException ( ref _error , ExceptionHelper . Extract ( t . Exception ) ) ;
@@ -186,7 +197,7 @@ private void NextHandler(Task<bool> t)
186197 }
187198 else if ( t . Result )
188199 {
189- var cts = CancellationTokenSource . CreateLinkedTokenSource ( _ct ) ;
200+ var cts = CancellationTokenSource . CreateLinkedTokenSource ( _sourceCTS . Token ) ;
190201 IAsyncEnumerator < TResult > src ;
191202 try
192203 {
@@ -312,7 +323,15 @@ private bool TryDispose()
312323
313324 private void InnerNextHandler ( Task < bool > t )
314325 {
315- if ( t . IsFaulted )
326+ if ( t . IsCanceled )
327+ {
328+ ExceptionHelper . AddException ( ref _parent . _error , new OperationCanceledException ( ) ) ;
329+ Done = true ;
330+ if ( TryDispose ( ) )
331+ {
332+ ResumeHelper . Resume ( ref _parent . _resume ) ;
333+ }
334+ } else if ( t . IsFaulted )
316335 {
317336 ExceptionHelper . AddException ( ref _parent . _error , ExceptionHelper . Extract ( t . Exception ) ) ;
318337 Done = true ;
0 commit comments