@@ -75,6 +75,7 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
7575 var buffer = [ Element] ( )
7676 var finished = false
7777 var failure : Failure ?
78+ var cancelled = false
7879 var limit : CheckedContinuation < Bool , Never > ?
7980 var demand : CheckedContinuation < Void , Never > ?
8081
@@ -155,14 +156,17 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
155156 }
156157
157158 func cancel( ) {
158- // TODO: this currently is a hard cancel, it should be refined to only cancel when everything is terminal
159159 let ( task, limit, demand, cancelled) = state. withLock { state -> ( IteratingTask ? , CheckedContinuation < Bool , Never > ? , CheckedContinuation < Void , Never > ? , Bool ) in
160- defer {
161- state. iteratingTask = . cancelled
162- state. limit = nil
163- state. demand = nil
160+ if state. sides. count == 0 {
161+ defer {
162+ state. iteratingTask = . cancelled
163+ state. cancelled = true
164+ }
165+ return state. emit ( state. iteratingTask)
166+ } else {
167+ state. cancelled = true
168+ return state. emit ( nil )
164169 }
165- return state. emit ( state. iteratingTask)
166170 }
167171 task? . cancel ( )
168172 limit? . resume ( returning: cancelled)
@@ -178,21 +182,32 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
178182 }
179183
180184 func unregisterSide( _ id: Int ) {
181- let ( side, continuation, cancelled) = state. withLock { state -> ( Side . State ? , CheckedContinuation < Bool , Never > ? , Bool ) in
185+ let ( side, continuation, cancelled, iteratingTaskToCancel ) = state. withLock { state -> ( Side . State ? , CheckedContinuation < Bool , Never > ? , Bool , IteratingTask ? ) in
182186 let side = state. sides. removeValue ( forKey: id)
183187 state. trimBuffer ( )
188+ let cancelRequested = state. sides. count == 0 && state. cancelled
184189 if let limit, state. buffer. count < limit {
185190 defer { state. limit = nil }
186191 if case . cancelled = state. iteratingTask {
187- return ( side, state. limit, true )
192+ return ( side, state. limit, true , nil )
188193 } else {
189- return ( side, state. limit, false )
194+ defer {
195+ if cancelRequested {
196+ state. iteratingTask = . cancelled
197+ }
198+ }
199+ return ( side, state. limit, false , cancelRequested ? state. iteratingTask : nil )
190200 }
191201 } else {
192202 if case . cancelled = state. iteratingTask {
193- return ( side, nil , true )
203+ return ( side, nil , true , nil )
194204 } else {
195- return ( side, nil , false )
205+ defer {
206+ if cancelRequested {
207+ state. iteratingTask = . cancelled
208+ }
209+ }
210+ return ( side, nil , false , cancelRequested ? state. iteratingTask : nil )
196211 }
197212 }
198213 }
@@ -202,6 +217,9 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
202217 if let side {
203218 side. continuaton? . resume ( returning: . success( nil ) )
204219 }
220+ if let iteratingTaskToCancel {
221+ iteratingTaskToCancel. cancel ( )
222+ }
205223 }
206224
207225 func iterate( ) async -> Bool {
0 commit comments