Skip to content

Commit d5d16e3

Browse files
authored
Test cancel connection request (#439)
1 parent f0bfba7 commit d5d16e3

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,64 @@ final class ConnectionPoolTests: XCTestCase {
368368
}
369369
}
370370

371-
}
371+
func testCancelConnectionRequestWorks() async throws {
372+
let clock = MockClock()
373+
let factory = MockConnectionFactory<MockClock>()
374+
let keepAliveDuration = Duration.seconds(30)
375+
let keepAlive = MockPingPongBehavior(keepAliveFrequency: keepAliveDuration, connectionType: MockConnection.self)
376+
377+
var mutableConfig = ConnectionPoolConfiguration()
378+
mutableConfig.minimumConnectionCount = 0
379+
mutableConfig.maximumConnectionSoftLimit = 4
380+
mutableConfig.maximumConnectionHardLimit = 4
381+
mutableConfig.idleTimeout = .seconds(10)
382+
let config = mutableConfig
383+
384+
let pool = ConnectionPool(
385+
configuration: config,
386+
idGenerator: ConnectionIDGenerator(),
387+
requestType: ConnectionRequest<MockConnection>.self,
388+
keepAliveBehavior: keepAlive,
389+
observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self),
390+
clock: clock
391+
) {
392+
try await factory.makeConnection(id: $0, for: $1)
393+
}
372394

395+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
396+
taskGroup.addTask {
397+
await pool.run()
398+
}
399+
400+
let leaseTask = Task {
401+
_ = try await pool.leaseConnection()
402+
}
403+
404+
let connectionAttemptWaiter = Waiter(of: Void.self)
405+
406+
taskGroup.addTask {
407+
try await factory.nextConnectAttempt { connectionID in
408+
connectionAttemptWaiter.yield(value: ())
409+
throw CancellationError()
410+
}
411+
}
412+
413+
try await connectionAttemptWaiter.result
414+
leaseTask.cancel()
415+
416+
let taskResult = await leaseTask.result
417+
switch taskResult {
418+
case .success:
419+
XCTFail("Expected task failure")
420+
case .failure(let failure):
421+
XCTAssertEqual(failure as? ConnectionPoolError, .requestCancelled)
422+
}
423+
424+
taskGroup.cancelAll()
425+
for connection in factory.runningConnections {
426+
connection.closeIfClosing()
427+
}
428+
}
429+
}
430+
}
373431

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import Atomics
2+
@testable import _ConnectionPoolModule
3+
4+
final class Waiter<Result: Sendable>: Sendable {
5+
struct State: Sendable {
6+
7+
var result: Swift.Result<Result, any Error>? = nil
8+
var continuations: [(Int, CheckedContinuation<Result, any Error>)] = []
9+
10+
}
11+
12+
let waiterID = ManagedAtomic(0)
13+
let stateBox: NIOLockedValueBox<State> = NIOLockedValueBox(State())
14+
15+
init(of: Result.Type) {}
16+
17+
enum GetAction {
18+
case fail(any Error)
19+
case succeed(Result)
20+
case none
21+
}
22+
23+
var result: Result {
24+
get async throws {
25+
let waiterID = self.waiterID.loadThenWrappingIncrement(ordering: .relaxed)
26+
27+
return try await withTaskCancellationHandler {
28+
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Result, any Error>) in
29+
let action = self.stateBox.withLockedValue { state -> GetAction in
30+
if Task.isCancelled {
31+
return .fail(CancellationError())
32+
}
33+
34+
switch state.result {
35+
case .none:
36+
state.continuations.append((waiterID, continuation))
37+
return .none
38+
39+
case .success(let result):
40+
return .succeed(result)
41+
42+
case .failure(let error):
43+
return .fail(error)
44+
}
45+
}
46+
47+
switch action {
48+
case .fail(let error):
49+
continuation.resume(throwing: error)
50+
51+
case .succeed(let result):
52+
continuation.resume(returning: result)
53+
54+
case .none:
55+
break
56+
}
57+
}
58+
} onCancel: {
59+
let cont = self.stateBox.withLockedValue { state -> CheckedContinuation<Result, any Error>? in
60+
guard state.result == nil else { return nil }
61+
62+
guard let contIndex = state.continuations.firstIndex(where: { $0.0 == waiterID }) else {
63+
return nil
64+
}
65+
let (_, continuation) = state.continuations.remove(at: contIndex)
66+
return continuation
67+
}
68+
69+
cont?.resume(throwing: CancellationError())
70+
}
71+
}
72+
}
73+
74+
func yield(value: Result) {
75+
let continuations = self.stateBox.withLockedValue { state in
76+
guard state.result == nil else {
77+
return [(Int, CheckedContinuation<Result, any Error>)]().lazy.map(\.1)
78+
}
79+
state.result = .success(value)
80+
81+
let continuations = state.continuations
82+
state.continuations = []
83+
84+
return continuations.lazy.map(\.1)
85+
}
86+
87+
for continuation in continuations {
88+
continuation.resume(returning: value)
89+
}
90+
}
91+
92+
func yield(error: any Error) {
93+
let continuations = self.stateBox.withLockedValue { state in
94+
guard state.result == nil else {
95+
return [(Int, CheckedContinuation<Result, any Error>)]().lazy.map(\.1)
96+
}
97+
state.result = .failure(error)
98+
99+
let continuations = state.continuations
100+
state.continuations = []
101+
102+
return continuations.lazy.map(\.1)
103+
}
104+
105+
for continuation in continuations {
106+
continuation.resume(throwing: error)
107+
}
108+
}
109+
}

0 commit comments

Comments
 (0)