Skip to content

Commit 6c85b3e

Browse files
committed
move to the suggested way of doing things
1 parent 830c3e7 commit 6c85b3e

File tree

6 files changed

+205
-61
lines changed

6 files changed

+205
-61
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.DS_Store
22
/.build
3+
/.index-build
34
/Packages
45
/*.xcodeproj
56
DerivedData

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,43 @@ extension PostgresConnection {
438438
}
439439
}
440440

441+
// use this for queries where you want to consume the rows.
442+
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
443+
public func query<Result>(
444+
_ query: PostgresQuery,
445+
logger: Logger,
446+
file: String = #fileID,
447+
line: Int = #line,
448+
_ consume: (PostgresRowSequence) async throws -> Result
449+
) async throws -> (Result, PostgresQueryMetadata) {
450+
var logger = logger
451+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
452+
453+
guard query.binds.count <= Int(UInt16.max) else {
454+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
455+
}
456+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
457+
let context = ExtendedQueryContext(
458+
query: query,
459+
logger: logger,
460+
promise: promise
461+
)
462+
463+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
464+
465+
do {
466+
let rowSequence = try await promise.futureResult.map({ $0.asyncSequence() }).get()
467+
let result = try await consume(rowSequence)
468+
let metadata = try await rowSequence.drainAndCollectMetadata()
469+
return (result, metadata)
470+
} catch var error as PSQLError {
471+
error.file = file
472+
error.line = line
473+
error.query = query
474+
throw error // rethrow with more metadata
475+
}
476+
}
477+
441478
/// Start listening for a channel
442479
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
443480
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
@@ -531,6 +568,42 @@ extension PostgresConnection {
531568
}
532569
}
533570

571+
// use this for queries where you want to consume the rows.
572+
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
573+
@discardableResult
574+
public func execute(
575+
_ query: PostgresQuery,
576+
logger: Logger,
577+
file: String = #fileID,
578+
line: Int = #line
579+
) async throws -> PostgresQueryMetadata {
580+
var logger = logger
581+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
582+
583+
guard query.binds.count <= Int(UInt16.max) else {
584+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
585+
}
586+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
587+
let context = ExtendedQueryContext(
588+
query: query,
589+
logger: logger,
590+
promise: promise
591+
)
592+
593+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
594+
595+
do {
596+
let rowSequence = try await promise.futureResult.map({ $0.asyncSequence() }).get()
597+
let metadata = try await rowSequence.drainAndCollectMetadata()
598+
return metadata
599+
} catch var error as PSQLError {
600+
error.file = file
601+
error.line = line
602+
error.query = query
603+
throw error // rethrow with more metadata
604+
}
605+
}
606+
534607
#if compiler(>=6.0)
535608
/// Puts the connection into an open transaction state, for the provided `closure`'s lifetime.
536609
///

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,63 @@ final class PSQLRowStream: @unchecked Sendable {
276276
return self.eventLoop.makeFailedFuture(error)
277277
}
278278
}
279-
279+
280+
// MARK: Drain on EventLoop
281+
282+
func drain() -> EventLoopFuture<Void> {
283+
if self.eventLoop.inEventLoop {
284+
return self.drain0()
285+
} else {
286+
return self.eventLoop.flatSubmit {
287+
self.drain0()
288+
}
289+
}
290+
}
291+
292+
private func drain0() -> EventLoopFuture<Void> {
293+
self.eventLoop.preconditionInEventLoop()
294+
295+
switch self.downstreamState {
296+
case .waitingForConsumer(let bufferState):
297+
switch bufferState {
298+
case .streaming(var buffer, let dataSource):
299+
let promise = self.eventLoop.makePromise(of: Void.self)
300+
301+
buffer.removeAll()
302+
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
303+
// immediately request more
304+
dataSource.request(for: self)
305+
306+
return promise.futureResult
307+
308+
case .finished(_, let summary):
309+
self.downstreamState = .consumed(.success(summary))
310+
return self.eventLoop.makeSucceededVoidFuture()
311+
312+
case .failure(let error):
313+
self.downstreamState = .consumed(.failure(error))
314+
return self.eventLoop.makeFailedFuture(error)
315+
}
316+
case .asyncSequence(let consumer, let dataSource, _):
317+
consumer.finish()
318+
319+
let promise = self.eventLoop.makePromise(of: Void.self)
320+
321+
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
322+
// immediately request more
323+
dataSource.request(for: self)
324+
325+
return promise.futureResult
326+
case .consumed(.success):
327+
// already drained
328+
return self.eventLoop.makeSucceededVoidFuture()
329+
case .consumed(let .failure(error)):
330+
return self.eventLoop.makeFailedFuture(error)
331+
default:
332+
preconditionFailure("Invalid state: \(self.downstreamState)")
333+
}
334+
}
335+
280336
internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) {
281337
self.logger.debug("Notice Received", metadata: [
282338
.notice: "\(notice)"

Sources/PostgresNIO/New/PostgresRowSequence.swift

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public struct PostgresRowSequence: AsyncSequence, Sendable {
3030
columns: self.columns
3131
)
3232
}
33+
34+
func drain() {
35+
self.backing
36+
}
3337
}
3438

3539
extension PostgresRowSequence {
@@ -64,43 +68,11 @@ extension PostgresRowSequence {
6468
extension PostgresRowSequence.AsyncIterator: Sendable {}
6569

6670
extension PostgresRowSequence {
67-
/// Collect and return all rows.
68-
/// - Returns: The rows.
69-
public func collect() async throws -> [PostgresRow] {
70-
var result = [PostgresRow]()
71-
for try await row in self {
72-
result.append(row)
73-
}
74-
return result
75-
}
76-
77-
/// Collect and return all rows, alongside the query metadata.
78-
/// - Returns: The query metadata and the rows.
79-
public func collectWithMetadata() async throws -> (metadata: PostgresQueryMetadata, rows: [PostgresRow]) {
80-
let rows = try await self.collect()
81-
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
82-
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
83-
}
84-
return (metadata, rows)
85-
}
86-
87-
/// Consumes all rows and returns the query metadata.
88-
///
89-
/// If you don't need the returned query metadata, just use the for-try-await-loop syntax:
90-
/// ```swift
91-
/// for try await row in myRowSequence {
92-
/// /// Process each row
93-
/// }
94-
/// ```
95-
///
96-
/// - Parameter onRow: Processes each row.
97-
/// - Returns: The query metadata.
98-
public func consume(
99-
onRow: @Sendable (PostgresRow) throws -> ()
100-
) async throws -> PostgresQueryMetadata {
101-
for try await row in self {
102-
try onRow(row)
103-
}
71+
/// Collects the query metadata.
72+
/// Should be called after the sequence is consumed, otherwise throws `PSQLError`.
73+
/// - Returns: The metadata.
74+
func drainAndCollectMetadata() async throws -> PostgresQueryMetadata {
75+
try await self.rowStream.drain().get()
10476
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
10577
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
10678
}

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
4747
}
4848
}
4949

50-
func testSelect10kRowsAndConsume() async throws {
50+
func testSelect10kRowsWithMetadata() async throws {
5151
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
5252
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
5353
let eventLoop = eventLoopGroup.next()
@@ -56,24 +56,28 @@ final class AsyncPostgresConnectionTests: XCTestCase {
5656
let end = 10000
5757

5858
try await withTestConnection(on: eventLoop) { connection in
59-
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
60-
61-
let counter = ManagedAtomic(0)
62-
let metadata = try await rows.consume { row in
63-
let element = try row.decode(Int.self)
64-
let newCounter = counter.wrappingIncrementThenLoad(ordering: .relaxed)
65-
XCTAssertEqual(element, newCounter)
59+
let (result, metadata) = try await connection.query(
60+
"SELECT generate_series(\(start), \(end));",
61+
logger: .psqlTest
62+
) { rows in
63+
var counter = 0
64+
for try await row in rows {
65+
let element = try row.decode(Int.self)
66+
XCTAssertEqual(element, counter + 1)
67+
counter += 1
68+
}
69+
return counter
6670
}
6771

6872
XCTAssertEqual(metadata.command, "SELECT")
6973
XCTAssertEqual(metadata.oid, nil)
70-
XCTAssertEqual(metadata.rows, 10000)
74+
XCTAssertEqual(metadata.rows, end)
7175

72-
XCTAssertEqual(counter.load(ordering: .relaxed), end)
76+
XCTAssertEqual(result, end)
7377
}
7478
}
7579

76-
func testSelect10kRowsAndCollect() async throws {
80+
func testSelectRowsWithMetadataNotConsumedAtAll() async throws {
7781
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
7882
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
7983
let eventLoop = eventLoopGroup.next()
@@ -82,20 +86,56 @@ final class AsyncPostgresConnectionTests: XCTestCase {
8286
let end = 10000
8387

8488
try await withTestConnection(on: eventLoop) { connection in
85-
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
86-
let (metadata, elements) = try await rows.collectWithMetadata()
87-
var counter = 0
88-
for row in elements {
89-
let element = try row.decode(Int.self)
90-
XCTAssertEqual(element, counter + 1)
91-
counter += 1
92-
}
89+
let (_, metadata) = try await connection.query(
90+
"SELECT generate_series(\(start), \(end));",
91+
logger: .psqlTest
92+
) { _ in }
9393

9494
XCTAssertEqual(metadata.command, "SELECT")
9595
XCTAssertEqual(metadata.oid, nil)
96-
XCTAssertEqual(metadata.rows, 10000)
96+
XCTAssertEqual(metadata.rows, end)
97+
}
98+
}
9799

98-
XCTAssertEqual(counter, end)
100+
func testSelectRowsWithMetadataNotFullyConsumed() async throws {
101+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
102+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
103+
let eventLoop = eventLoopGroup.next()
104+
105+
try await withTestConnection(on: eventLoop) { connection in
106+
do {
107+
_ = try await connection.query(
108+
"SELECT generate_series(1, 10000);",
109+
logger: .psqlTest
110+
) { rows in
111+
for try await _ in rows { break }
112+
}
113+
XCTFail("Expected a failure")
114+
} catch is CancellationError {
115+
// Expected
116+
} catch {
117+
XCTFail("Expected 'CancellationError', got: \(String(reflecting: error))")
118+
}
119+
}
120+
}
121+
122+
func testExecuteRowsWithMetadata() async throws {
123+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
124+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
125+
let eventLoop = eventLoopGroup.next()
126+
127+
let start = 1
128+
let end = 10000
129+
130+
try await withTestConnection(on: eventLoop) { connection in
131+
let metadata = try await connection.execute(
132+
"SELECT generate_series(\(start), \(end));",
133+
logger: .psqlTest
134+
)
135+
136+
XCTAssertEqual(metadata.command, "SELECT")
137+
XCTAssertEqual(metadata.oid, nil)
138+
XCTAssertEqual(metadata.rows, end)
99139
}
100140
}
101141

@@ -294,8 +334,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
294334
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
295335
binds: binds
296336
)
297-
let result = try await connection.query(insertionQuery, logger: .psqlTest)
298-
let metadata = try await result.collectWithMetadata().metadata
337+
let metadata = try await connection.execute(insertionQuery, logger: .psqlTest)
299338
XCTAssertEqual(metadata.rows, rowsCount)
300339

301340
let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")

docker-compose.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ x-shared-config: &shared_config
1010
- 5432:5432
1111

1212
services:
13+
psql-17:
14+
image: postgres:17
15+
<<: *shared_config
1316
psql-16:
1417
image: postgres:16
1518
<<: *shared_config

0 commit comments

Comments
 (0)