Skip to content

Commit e70add9

Browse files
authored
Merge branch 'main' into mmbm-row-seq-expose-metadata
2 parents 3933fc4 + f2a6394 commit e70add9

File tree

5 files changed

+167
-0
lines changed

5 files changed

+167
-0
lines changed

Sources/PostgresNIO/New/PostgresQuery.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable {
172172
try self.append(value, context: .default)
173173
}
174174

175+
@inlinable
176+
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable>(_ value: Optional<Value>) throws {
177+
switch value {
178+
case .none:
179+
self.appendNull()
180+
case let .some(value):
181+
try self.append(value)
182+
}
183+
}
184+
175185
@inlinable
176186
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
177187
_ value: Value,
@@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable {
181191
self.metadata.append(.init(value: value, protected: true))
182192
}
183193

194+
@inlinable
195+
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
196+
_ value: Optional<Value>,
197+
context: PostgresEncodingContext<JSONEncoder>
198+
) throws {
199+
switch value {
200+
case .none:
201+
self.appendNull()
202+
case let .some(value):
203+
try self.append(value, context: context)
204+
}
205+
}
206+
184207
@inlinable
185208
public mutating func append<Value: PostgresDynamicTypeEncodable>(_ value: Value) {
186209
self.append(value, context: .default)
187210
}
188211

212+
@inlinable
213+
public mutating func append<Value: PostgresDynamicTypeEncodable>(_ value: Optional<Value>) {
214+
switch value {
215+
case .none:
216+
self.appendNull()
217+
case let .some(value):
218+
self.append(value)
219+
}
220+
}
221+
189222
@inlinable
190223
public mutating func append<Value: PostgresDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
191224
_ value: Value,
@@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable {
195228
self.metadata.append(.init(value: value, protected: true))
196229
}
197230

231+
@inlinable
232+
public mutating func append<Value: PostgresDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
233+
_ value: Optional<Value>,
234+
context: PostgresEncodingContext<JSONEncoder>
235+
) {
236+
switch value {
237+
case .none:
238+
self.appendNull()
239+
case let .some(value):
240+
self.append(value, context: context)
241+
}
242+
}
243+
198244
@inlinable
199245
mutating func appendUnprotected<Value: PostgresEncodable, JSONEncoder: PostgresJSONEncoder>(
200246
_ value: Value,

Sources/PostgresNIO/Pool/ConnectionFactory.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ final class ConnectionFactory: Sendable {
8989
connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout)
9090
connectionConfig.options.tlsServerName = config.options.tlsServerName
9191
connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData
92+
connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters
9293

9394
return connectionConfig
9495
}

Sources/PostgresNIO/Pool/PostgresClient.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service {
106106
/// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default).
107107
public var requireBackendKeyData: Bool = true
108108

109+
/// Additional parameters to send to the server on startup. The name value pairs are added to the initial
110+
/// startup message that the client sends to the server.
111+
public var additionalStartupParameters: [(String, String)] = []
112+
109113
/// The minimum number of connections that the client shall keep open at any time, even if there is no
110114
/// demand. Default to `0`.
111115
///

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,87 @@ final class AsyncPostgresConnectionTests: XCTestCase {
525525
XCTFail("Unexpected error: \(String(describing: error))")
526526
}
527527
}
528+
529+
static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable"
530+
func testPreparedStatementWithOptionalBinding() async throws {
531+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
532+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
533+
let eventLoop = eventLoopGroup.next()
534+
535+
struct InsertPreparedStatement: PostgresPreparedStatement {
536+
static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable"
537+
538+
static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"#
539+
typealias Row = ()
540+
541+
var uuid: UUID?
542+
543+
func makeBindings() -> PostgresBindings {
544+
var bindings = PostgresBindings()
545+
bindings.append(self.uuid)
546+
return bindings
547+
}
548+
549+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
550+
()
551+
}
552+
}
553+
554+
struct SelectPreparedStatement: PostgresPreparedStatement {
555+
static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable"
556+
557+
static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"#
558+
typealias Row = (Int, UUID?)
559+
560+
var id: Int
561+
562+
func makeBindings() -> PostgresBindings {
563+
var bindings = PostgresBindings()
564+
bindings.append(self.id)
565+
return bindings
566+
}
567+
568+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
569+
try row.decode((Int, UUID?).self)
570+
}
571+
}
572+
573+
do {
574+
try await withTestConnection(on: eventLoop) { connection in
575+
try await connection.query("""
576+
CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" (
577+
id SERIAL PRIMARY KEY,
578+
uuid UUID
579+
)
580+
""",
581+
logger: .psqlTest
582+
)
583+
584+
_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)
585+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
586+
_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)
587+
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
588+
_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)
589+
590+
let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest)
591+
var counter = 0
592+
for try await (id, uuid) in rows {
593+
Logger.psqlTest.info("Received row", metadata: [
594+
"id": "\(id)", "uuid": "\(String(describing: uuid))"
595+
])
596+
counter += 1
597+
}
598+
599+
try await connection.query("""
600+
DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)";
601+
""",
602+
logger: .psqlTest
603+
)
604+
}
605+
} catch {
606+
XCTFail("Unexpected error: \(String(describing: error))")
607+
}
608+
}
528609
}
529610

530611
extension XCTestCase {

Tests/IntegrationTests/PostgresClientTests.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,41 @@ final class PostgresClientTests: XCTestCase {
4343
}
4444
}
4545

46+
func testApplicationNameIsForwardedCorrectly() async throws {
47+
var mlogger = Logger(label: "test")
48+
mlogger.logLevel = .debug
49+
let logger = mlogger
50+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
51+
self.addTeardownBlock {
52+
try await eventLoopGroup.shutdownGracefully()
53+
}
54+
55+
var clientConfig = PostgresClient.Configuration.makeTestConfiguration()
56+
let applicationName = "postgres_nio_test_run"
57+
clientConfig.options.additionalStartupParameters = [("application_name", applicationName)]
58+
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
59+
60+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
61+
taskGroup.addTask {
62+
await client.run()
63+
}
64+
65+
let rows = try await client.query("select * from pg_stat_activity;");
66+
var applicationNameFound = 0
67+
for try await row in rows {
68+
let randomAccessRow = row.makeRandomAccess()
69+
if try randomAccessRow["application_name"].decode(String?.self) == applicationName {
70+
applicationNameFound += 1
71+
}
72+
}
73+
74+
XCTAssertGreaterThanOrEqual(applicationNameFound, 1)
75+
76+
taskGroup.cancelAll()
77+
}
78+
}
79+
80+
4681
func testQueryDirectly() async throws {
4782
var mlogger = Logger(label: "test")
4883
mlogger.logLevel = .debug

0 commit comments

Comments
 (0)