Skip to content

Commit a48eebc

Browse files
authored
Actually use additional connection parameters (vapor#473)
1 parent ee5d5e1 commit a48eebc

File tree

4 files changed

+82
-5
lines changed

4 files changed

+82
-5
lines changed

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
390390
let authContext = AuthContext(
391391
username: username,
392392
password: self.configuration.password,
393-
database: self.configuration.database
393+
database: self.configuration.database,
394+
additionalParameters: self.configuration.options.additionalStartupParameters
394395
)
395396
let action = self.state.provideAuthenticationContext(authContext)
396397
return self.run(action, with: context)

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,36 @@ final class AsyncPostgresConnectionTests: XCTestCase {
8484
}
8585
}
8686

87+
func testAdditionalParametersTakeEffect() async throws {
88+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
89+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
90+
let eventLoop = eventLoopGroup.next()
91+
92+
let query: PostgresQuery = """
93+
SELECT
94+
current_setting('application_name');
95+
"""
96+
97+
let applicationName = "postgres-nio-test"
98+
var options = PostgresConnection.Configuration.Options()
99+
options.additionalStartupParameters = [
100+
("application_name", applicationName)
101+
]
102+
103+
try await withTestConnection(on: eventLoop, options: options) { connection in
104+
let rows = try await connection.query(query, logger: .psqlTest)
105+
var counter = 0
106+
107+
for try await element in rows.decode(String.self) {
108+
XCTAssertEqual(element, applicationName)
109+
110+
counter += 1
111+
}
112+
113+
XCTAssertGreaterThanOrEqual(counter, 1)
114+
}
115+
}
116+
87117
func testSelectTimeoutWhileLongRunningQuery() async throws {
88118
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
89119
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -452,11 +482,12 @@ extension XCTestCase {
452482

453483
func withTestConnection<Result>(
454484
on eventLoop: EventLoop,
485+
options: PostgresConnection.Configuration.Options? = nil,
455486
file: StaticString = #filePath,
456487
line: UInt = #line,
457488
_ closure: (PostgresConnection) async throws -> Result
458489
) async throws -> Result {
459-
let connection = try await PostgresConnection.test(on: eventLoop).get()
490+
let connection = try await PostgresConnection.test(on: eventLoop, options: options).get()
460491

461492
do {
462493
let result = try await closure(connection)

Tests/IntegrationTests/Utilities.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@ extension PostgresConnection {
2424
}
2525
}
2626

27-
static func test(on eventLoop: EventLoop) -> EventLoopFuture<PostgresConnection> {
27+
static func test(on eventLoop: EventLoop, options: Configuration.Options? = nil) -> EventLoopFuture<PostgresConnection> {
2828
let logger = Logger(label: "postgres.connection.test")
29-
let config = PostgresConnection.Configuration(
29+
var config = PostgresConnection.Configuration(
3030
host: env("POSTGRES_HOSTNAME") ?? "localhost",
3131
port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432,
3232
username: env("POSTGRES_USER") ?? "test_username",
3333
password: env("POSTGRES_PASSWORD") ?? "test_password",
3434
database: env("POSTGRES_DB") ?? "test_database",
3535
tls: .disable
3636
)
37-
37+
if let options {
38+
config.options = options
39+
}
40+
3841
return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger)
3942
}
4043

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,48 @@ class PostgresConnectionTests: XCTestCase {
3838
}
3939
}
4040

41+
func testOptionsAreSentOnTheWire() async throws {
42+
let eventLoop = NIOAsyncTestingEventLoop()
43+
let channel = await NIOAsyncTestingChannel(handlers: [
44+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
45+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
46+
], loop: eventLoop)
47+
try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432))
48+
49+
let configuration = {
50+
var config = PostgresConnection.Configuration(
51+
establishedChannel: channel,
52+
username: "username",
53+
password: "postgres",
54+
database: "database"
55+
)
56+
config.options.additionalStartupParameters = [
57+
("DateStyle", "ISO, MDY"),
58+
("application_name", "postgres-nio-test"),
59+
("server_encoding", "UTF8"),
60+
("integer_datetimes", "on"),
61+
("client_encoding", "UTF8"),
62+
("TimeZone", "Etc/UTC"),
63+
("is_superuser", "on"),
64+
("server_version", "13.1 (Debian 13.1-1.pgdg100+1)"),
65+
("session_authorization", "postgres"),
66+
("IntervalStyle", "postgres"),
67+
("standard_conforming_strings", "on")
68+
]
69+
return config
70+
}()
71+
72+
async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest)
73+
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
74+
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false))))
75+
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
76+
try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))
77+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
78+
79+
let connection = try await connectionPromise
80+
try await connection.close()
81+
}
82+
4183
func testSimpleListen() async throws {
4284
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
4385

0 commit comments

Comments
 (0)