Skip to content

Commit 85d189c

Browse files
authored
Run queries directly on PostgresClient (#456)
1 parent 6433f6d commit 85d189c

File tree

3 files changed

+104
-12
lines changed

3 files changed

+104
-12
lines changed

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ final class PSQLRowStream: @unchecked Sendable {
3535
case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
3636
case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource)
3737
case consumed(Result<String, Error>)
38-
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource)
38+
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ())
3939
}
4040

4141
internal let rowDescription: [RowDescription.Column]
@@ -75,7 +75,7 @@ final class PSQLRowStream: @unchecked Sendable {
7575

7676
// MARK: Async Sequence
7777

78-
func asyncSequence() -> PostgresRowSequence {
78+
func asyncSequence(onFinish: @escaping @Sendable () -> () = {}) -> PostgresRowSequence {
7979
self.eventLoop.preconditionInEventLoop()
8080

8181
guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
@@ -95,13 +95,13 @@ final class PSQLRowStream: @unchecked Sendable {
9595
switch bufferState {
9696
case .streaming(let bufferedRows, let dataSource):
9797
let yieldResult = source.yield(contentsOf: bufferedRows)
98-
self.downstreamState = .asyncSequence(source, dataSource)
99-
98+
self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish)
10099
self.executeActionBasedOnYieldResult(yieldResult, source: dataSource)
101100

102101
case .finished(let buffer, let commandTag):
103102
_ = source.yield(contentsOf: buffer)
104103
source.finish()
104+
onFinish()
105105
self.downstreamState = .consumed(.success(commandTag))
106106

107107
case .failure(let error):
@@ -130,7 +130,7 @@ final class PSQLRowStream: @unchecked Sendable {
130130
case .consumed:
131131
break
132132

133-
case .asyncSequence(_, let dataSource):
133+
case .asyncSequence(_, let dataSource, _):
134134
dataSource.request(for: self)
135135
}
136136
}
@@ -147,9 +147,10 @@ final class PSQLRowStream: @unchecked Sendable {
147147

148148
private func cancel0() {
149149
switch self.downstreamState {
150-
case .asyncSequence(_, let dataSource):
150+
case .asyncSequence(_, let dataSource, let onFinish):
151151
self.downstreamState = .consumed(.failure(CancellationError()))
152152
dataSource.cancel(for: self)
153+
onFinish()
153154

154155
case .consumed:
155156
return
@@ -320,7 +321,7 @@ final class PSQLRowStream: @unchecked Sendable {
320321
// immediately request more
321322
dataSource.request(for: self)
322323

323-
case .asyncSequence(let consumer, let source):
324+
case .asyncSequence(let consumer, let source, _):
324325
let yieldResult = consumer.yield(contentsOf: newRows)
325326
self.executeActionBasedOnYieldResult(yieldResult, source: source)
326327

@@ -359,10 +360,11 @@ final class PSQLRowStream: @unchecked Sendable {
359360
self.downstreamState = .consumed(.success(commandTag))
360361
promise.succeed(rows)
361362

362-
case .asyncSequence(let source, _):
363-
source.finish()
363+
case .asyncSequence(let source, _, let onFinish):
364364
self.downstreamState = .consumed(.success(commandTag))
365-
365+
source.finish()
366+
onFinish()
367+
366368
case .consumed:
367369
break
368370
}
@@ -384,9 +386,10 @@ final class PSQLRowStream: @unchecked Sendable {
384386
self.downstreamState = .consumed(.failure(error))
385387
promise.fail(error)
386388

387-
case .asyncSequence(let consumer, _):
388-
consumer.finish(error)
389+
case .asyncSequence(let consumer, _, let onFinish):
389390
self.downstreamState = .consumed(.failure(error))
391+
consumer.finish(error)
392+
onFinish()
390393

391394
case .consumed:
392395
break

Sources/PostgresNIO/Pool/PostgresClient.swift

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,58 @@ public final class PostgresClient: Sendable {
290290
return try await closure(connection)
291291
}
292292

293+
/// Run a query on the Postgres server the client is connected to.
294+
///
295+
/// - Parameters:
296+
/// - query: The ``PostgresQuery`` to run
297+
/// - logger: The `Logger` to log into for the query
298+
/// - file: The file, the query was started in. Used for better error reporting.
299+
/// - line: The line, the query was started in. Used for better error reporting.
300+
/// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result.
301+
/// The sequence be discarded.
302+
@discardableResult
303+
public func query(
304+
_ query: PostgresQuery,
305+
logger: Logger,
306+
file: String = #fileID,
307+
line: Int = #line
308+
) async throws -> PostgresRowSequence {
309+
do {
310+
guard query.binds.count <= Int(UInt16.max) else {
311+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
312+
}
313+
314+
let connection = try await self.leaseConnection()
315+
316+
var logger = logger
317+
logger[postgresMetadataKey: .connectionID] = "\(connection.id)"
318+
319+
let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
320+
let context = ExtendedQueryContext(
321+
query: query,
322+
logger: logger,
323+
promise: promise
324+
)
325+
326+
connection.channel.write(HandlerTask.extendedQuery(context), promise: nil)
327+
328+
promise.futureResult.whenFailure { _ in
329+
self.pool.releaseConnection(connection)
330+
}
331+
332+
return try await promise.futureResult.map {
333+
$0.asyncSequence(onFinish: {
334+
self.pool.releaseConnection(connection)
335+
})
336+
}.get()
337+
} catch var error as PSQLError {
338+
error.file = file
339+
error.line = line
340+
error.query = query
341+
throw error // rethrow with more metadata
342+
}
343+
}
344+
293345
/// The client's run method. Users must call this function in order to start the client's background task processing
294346
/// like creating and destroying connections and running timers.
295347
///

Tests/IntegrationTests/PostgresClientTests.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,43 @@ final class PostgresClientTests: XCTestCase {
4141
taskGroup.cancelAll()
4242
}
4343
}
44+
45+
func testQueryDirectly() async throws {
46+
var mlogger = Logger(label: "test")
47+
mlogger.logLevel = .debug
48+
let logger = mlogger
49+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
50+
self.addTeardownBlock {
51+
try await eventLoopGroup.shutdownGracefully()
52+
}
53+
54+
let clientConfig = PostgresClient.Configuration.makeTestConfiguration()
55+
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
56+
57+
await withThrowingTaskGroup(of: Void.self) { taskGroup in
58+
taskGroup.addTask {
59+
await client.run()
60+
}
61+
62+
for i in 0..<10000 {
63+
taskGroup.addTask {
64+
do {
65+
try await client.query("SELECT 1", logger: logger)
66+
logger.info("Success", metadata: ["run": "\(i)"])
67+
} catch {
68+
XCTFail("Unexpected error: \(error)")
69+
}
70+
}
71+
}
72+
73+
for _ in 0..<10000 {
74+
_ = await taskGroup.nextResult()!
75+
}
76+
77+
taskGroup.cancelAll()
78+
}
79+
}
80+
4481
}
4582

4683
@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)

0 commit comments

Comments
 (0)