Skip to content

Commit c8049b1

Browse files
committed
remove promise passing from send ssl and establish ssl actions
1 parent e73443b commit c8049b1

File tree

4 files changed

+55
-62
lines changed

4 files changed

+55
-62
lines changed

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ struct ConnectionStateMachine {
6464

6565
case read
6666
case wait
67-
case sendSSLRequest(EventLoopPromise<Void>?)
68-
case establishSSLConnection(EventLoopPromise<Void>?)
67+
case sendSSLRequest
68+
case establishSSLConnection
6969
case provideAuthenticationContext
7070
case forwardNotificationToListeners(PostgresBackendMessage.NotificationResponse)
7171
case fireEventReadyForQuery
@@ -131,7 +131,7 @@ struct ConnectionStateMachine {
131131
case require
132132
}
133133

134-
mutating func connected(tls: TLSConfiguration, promise: EventLoopPromise<Void>?) -> ConnectionAction {
134+
mutating func connected(tls: TLSConfiguration) -> ConnectionAction {
135135
switch self.state {
136136
case .initialized:
137137
switch tls {
@@ -141,11 +141,11 @@ struct ConnectionStateMachine {
141141

142142
case .prefer:
143143
self.state = .sslRequestSent(.prefer)
144-
return .sendSSLRequest(promise)
144+
return .sendSSLRequest
145145

146146
case .require:
147147
self.state = .sslRequestSent(.require)
148-
return .sendSSLRequest(promise)
148+
return .sendSSLRequest
149149
}
150150

151151
case .sslRequestSent,
@@ -164,11 +164,8 @@ struct ConnectionStateMachine {
164164
}
165165
}
166166

167-
mutating func provideAuthenticationContext(
168-
_ authContext: AuthContext,
169-
promise: EventLoopPromise<Void>?
170-
) -> ConnectionAction {
171-
self.startAuthentication(authContext, promise: promise)
167+
mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction {
168+
self.startAuthentication(authContext)
172169
}
173170

174171
mutating func gracefulClose(_ promise: EventLoopPromise<Void>?) -> ConnectionAction {
@@ -229,14 +226,14 @@ struct ConnectionStateMachine {
229226
}
230227
}
231228

232-
mutating func sslSupportedReceived(unprocessedBytes: Int, promise: EventLoopPromise<Void>?) -> ConnectionAction {
229+
mutating func sslSupportedReceived(unprocessedBytes: Int) -> ConnectionAction {
233230
switch self.state {
234231
case .sslRequestSent:
235232
if unprocessedBytes > 0 {
236-
return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest, closePromise: promise)
233+
return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest)
237234
}
238235
self.state = .sslNegotiated
239-
return .establishSSLConnection(promise)
236+
return .establishSSLConnection
240237

241238
case .initialized,
242239
.sslNegotiated,
@@ -249,7 +246,7 @@ struct ConnectionStateMachine {
249246
.closeCommand,
250247
.closing,
251248
.closed:
252-
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported), closePromise: promise)
249+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported))
253250

254251
case .modifying:
255252
preconditionFailure("Invalid state: \(self.state)")
@@ -805,10 +802,7 @@ struct ConnectionStateMachine {
805802

806803
// MARK: - Private Methods -
807804

808-
private mutating func startAuthentication(
809-
_ authContext: AuthContext,
810-
promise: EventLoopPromise<Void>?
811-
) -> ConnectionAction {
805+
private mutating func startAuthentication(_ authContext: AuthContext) -> ConnectionAction {
812806
guard case .waitingToStartAuthentication = self.state else {
813807
preconditionFailure("Can only start authentication after connect or ssl establish")
814808
}

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
6464
self.encoder = PostgresFrontendMessageEncoder(buffer: context.channel.allocator.buffer(capacity: 256))
6565

6666
if context.channel.isActive {
67-
self.connected(context: context, writePromise: nil)
67+
self.connected(context: context)
6868
}
6969
}
7070

@@ -80,7 +80,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
8080
// it receives a
8181
context.fireChannelActive()
8282

83-
self.connected(context: context, writePromise: nil)
83+
self.connected(context: context)
8484
}
8585

8686
func channelInactive(context: ChannelHandlerContext) {
@@ -161,7 +161,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
161161
case .rowDescription(let rowDescription):
162162
action = self.state.rowDescriptionReceived(rowDescription)
163163
case .sslSupported:
164-
action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes, promise: nil)
164+
action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes)
165165
case .sslUnsupported:
166166
action = self.state.sslUnsupportedReceived()
167167
}
@@ -285,8 +285,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
285285

286286
switch event {
287287
case PSQLOutgoingEvent.authenticate(let authContext):
288-
let action = self.state.provideAuthenticationContext(authContext, promise: promise)
288+
let action = self.state.provideAuthenticationContext(authContext)
289289
self.run(action, with: context)
290+
promise?.succeed(())
290291

291292
case PSQLOutgoingEvent.gracefulShutdown:
292293
let action = self.state.gracefulClose(promise)
@@ -333,16 +334,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
333334
self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"])
334335

335336
switch action {
336-
case .establishSSLConnection(let promise):
337-
self.establishSSLConnection(context: context, promise: promise)
337+
case .establishSSLConnection:
338+
self.establishSSLConnection(context: context)
338339
case .read:
339340
context.read()
340341
case .wait:
341342
break
342343
case .sendStartupMessage(let authContext):
343344
self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters)
344345
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
345-
case .sendSSLRequest(let promise):
346+
case .sendSSLRequest:
346347
self.encoder.ssl()
347348
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
348349
case .sendPasswordMessage(let mode, let authContext):
@@ -410,7 +411,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
410411
database: self.configuration.database,
411412
additionalParameters: self.configuration.options.additionalStartupParameters
412413
)
413-
let action = self.state.provideAuthenticationContext(authContext, promise: nil)
414+
let action = self.state.provideAuthenticationContext(authContext)
414415
return self.run(action, with: context)
415416
}
416417
case .fireEventReadyForQuery:
@@ -447,23 +448,21 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
447448

448449
// MARK: - Private Methods -
449450

450-
private func connected(context: ChannelHandlerContext, writePromise: EventLoopPromise<Void>?) {
451-
let action = self.state.connected(tls: .init(self.configuration.tls), promise: writePromise)
451+
private func connected(context: ChannelHandlerContext) {
452+
let action = self.state.connected(tls: .init(self.configuration.tls))
452453
self.run(action, with: context)
453454
}
454455

455-
private func establishSSLConnection(context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
456+
private func establishSSLConnection(context: ChannelHandlerContext) {
456457
// This method must only be called, if we signalized the StateMachine before that we are
457458
// able to setup a SSL connection.
458459
do {
459460
try self.configureSSLCallback!(context.channel)
460461
let action = self.state.sslHandlerAdded()
461462
self.run(action, with: context)
462-
promise?.succeed()
463463
} catch {
464464
let action = self.state.errorHappened(.failedToAddSSLHandler(underlying: error))
465465
self.run(action, with: context)
466-
promise?.fail(error)
467466
}
468467
}
469468

Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,48 @@ class AuthenticationStateMachineTests: XCTestCase {
88
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
99

1010
var state = ConnectionStateMachine(requireBackendKeyData: true)
11-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
11+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
1212

13-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
13+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
1414
XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext))
1515
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
1616
}
1717

1818
func testAuthenticateMD5() {
1919
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
2020
var state = ConnectionStateMachine(requireBackendKeyData: true)
21-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
21+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
2222
let salt: UInt32 = 0x00_01_02_03
2323

24-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
24+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
2525
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
2626
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
2727
}
2828

2929
func testAuthenticateMD5WithoutPassword() {
3030
let authContext = AuthContext(username: "test", password: nil, database: "test")
3131
var state = ConnectionStateMachine(requireBackendKeyData: true)
32-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
32+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
3333
let salt: UInt32 = 0x00_01_02_03
3434

35-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
35+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
3636
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)),
3737
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil)))
3838
}
3939

4040
func testAuthenticateOkAfterStartUpWithoutAuthChallenge() {
4141
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
4242
var state = ConnectionStateMachine(requireBackendKeyData: true)
43-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
44-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
43+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
44+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
4545
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
4646
}
4747

4848
func testAuthenticateSCRAMSHA256WithAtypicalEncoding() {
4949
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
5050
var state = ConnectionStateMachine(requireBackendKeyData: true)
51-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
52-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
51+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
52+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
5353

5454
let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"]))
5555
guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else {
@@ -72,10 +72,10 @@ class AuthenticationStateMachineTests: XCTestCase {
7272
func testAuthenticationFailure() {
7373
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
7474
var state = ConnectionStateMachine(requireBackendKeyData: true)
75-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
75+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
7676
let salt: UInt32 = 0x00_01_02_03
7777

78-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
78+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
7979
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
8080
let fields: [PostgresBackendMessage.Field: String] = [
8181
.message: "password authentication failed for user \"postgres\"",
@@ -104,8 +104,8 @@ class AuthenticationStateMachineTests: XCTestCase {
104104
for (message, mechanism) in unsupported {
105105
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
106106
var state = ConnectionStateMachine(requireBackendKeyData: true)
107-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
108-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
107+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
108+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
109109
XCTAssertEqual(state.authenticationMessageReceived(message),
110110
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil)))
111111
}
@@ -123,8 +123,8 @@ class AuthenticationStateMachineTests: XCTestCase {
123123
for message in unexpected {
124124
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
125125
var state = ConnectionStateMachine(requireBackendKeyData: true)
126-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
127-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
126+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
127+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
128128
XCTAssertEqual(state.authenticationMessageReceived(message),
129129
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil)))
130130
}
@@ -150,8 +150,8 @@ class AuthenticationStateMachineTests: XCTestCase {
150150
for message in unexpected {
151151
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
152152
var state = ConnectionStateMachine(requireBackendKeyData: true)
153-
XCTAssertEqual(state.connected(tls: .disable, promise: nil), .provideAuthenticationContext)
154-
XCTAssertEqual(state.provideAuthenticationContext(authContext, promise: nil), .sendStartupMessage(authContext))
153+
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
154+
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
155155
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
156156
XCTAssertEqual(state.authenticationMessageReceived(message),
157157
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil)))

0 commit comments

Comments
 (0)