Skip to content

Commit 39d42a6

Browse files
committed
Merge branch 'mmbm-row-seq-expose-metadata' into main-updated
2 parents 94c1284 + 68084f9 commit 39d42a6

File tree

3 files changed

+104
-16
lines changed

3 files changed

+104
-16
lines changed

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ final class PSQLRowStream: @unchecked Sendable {
4444
}
4545

4646
internal let rowDescription: [RowDescription.Column]
47-
private let lookupTable: [String: Int]
47+
internal let lookupTable: [String: Int]
4848
private var downstreamState: DownstreamState
4949

5050
init(
@@ -114,7 +114,7 @@ final class PSQLRowStream: @unchecked Sendable {
114114
self.downstreamState = .consumed(.failure(error))
115115
}
116116

117-
return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription)
117+
return PostgresRowSequence(producer.sequence, rowStream: self)
118118
}
119119

120120
func demand() {

Sources/PostgresNIO/New/PostgresRowSequence.swift

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@ public struct PostgresRowSequence: AsyncSequence, Sendable {
99

1010
typealias BackingSequence = NIOThrowingAsyncSequenceProducer<DataRow, Error, AdaptiveRowBuffer, PSQLRowStream>
1111

12-
let backing: BackingSequence
13-
let lookupTable: [String: Int]
14-
let columns: [RowDescription.Column]
12+
private let backing: BackingSequence
13+
private let rowStream: PSQLRowStream
14+
var lookupTable: [String: Int] {
15+
self.rowStream.lookupTable
16+
}
17+
var columns: [RowDescription.Column] {
18+
self.rowStream.rowDescription
19+
}
1520

16-
init(_ backing: BackingSequence, lookupTable: [String: Int], columns: [RowDescription.Column]) {
21+
init(_ backing: BackingSequence, rowStream: PSQLRowStream) {
1722
self.backing = backing
18-
self.lookupTable = lookupTable
19-
self.columns = columns
23+
self.rowStream = rowStream
2024
}
2125

2226
public func makeAsyncIterator() -> AsyncIterator {
@@ -60,13 +64,48 @@ extension PostgresRowSequence {
6064
extension PostgresRowSequence.AsyncIterator: Sendable {}
6165

6266
extension PostgresRowSequence {
67+
/// Collect and return all rows.
68+
/// - Returns: The rows.
6369
public func collect() async throws -> [PostgresRow] {
6470
var result = [PostgresRow]()
6571
for try await row in self {
6672
result.append(row)
6773
}
6874
return result
6975
}
76+
77+
/// Collect and return all rows, alongside the query metadata.
78+
/// - Returns: The query metadata and the rows.
79+
public func collectWithMetadata() async throws -> (metadata: PostgresQueryMetadata, rows: [PostgresRow]) {
80+
let rows = try await self.collect()
81+
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
82+
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
83+
}
84+
return (metadata, rows)
85+
}
86+
87+
/// Consumes all rows and returns the query metadata.
88+
///
89+
/// If you don't need the returned query metadata, just use the for-try-await-loop syntax:
90+
/// ```swift
91+
/// for try await row in myRowSequence {
92+
/// /// Process each row
93+
/// }
94+
/// ```
95+
///
96+
/// - Parameter onRow: Processes each row.
97+
/// - Returns: The query metadata.
98+
public func consume(
99+
onRow: @Sendable (PostgresRow) throws -> ()
100+
) async throws -> PostgresQueryMetadata {
101+
for try await row in self {
102+
try onRow(row)
103+
}
104+
guard let metadata = PostgresQueryMetadata(string: self.rowStream.commandTag) else {
105+
throw PSQLError.invalidCommandTag(self.rowStream.commandTag)
106+
}
107+
return metadata
108+
}
70109
}
71110

72111
struct AdaptiveRowBuffer: NIOAsyncSequenceProducerBackPressureStrategy {

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Atomics
12
import Logging
23
import XCTest
34
import PostgresNIO
@@ -46,6 +47,58 @@ final class AsyncPostgresConnectionTests: XCTestCase {
4647
}
4748
}
4849

50+
func testSelect10kRowsAndConsume() async throws {
51+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
52+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
53+
let eventLoop = eventLoopGroup.next()
54+
55+
let start = 1
56+
let end = 10000
57+
58+
try await withTestConnection(on: eventLoop) { connection in
59+
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
60+
61+
let counter = ManagedAtomic(0)
62+
let metadata = try await rows.consume { row in
63+
let element = try row.decode(Int.self)
64+
let newCounter = counter.wrappingIncrementThenLoad(ordering: .relaxed)
65+
XCTAssertEqual(element, newCounter)
66+
}
67+
68+
XCTAssertEqual(metadata.command, "SELECT")
69+
XCTAssertEqual(metadata.oid, nil)
70+
XCTAssertEqual(metadata.rows, 10000)
71+
72+
XCTAssertEqual(counter.load(ordering: .relaxed), end)
73+
}
74+
}
75+
76+
func testSelect10kRowsAndCollect() async throws {
77+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
78+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
79+
let eventLoop = eventLoopGroup.next()
80+
81+
let start = 1
82+
let end = 10000
83+
84+
try await withTestConnection(on: eventLoop) { connection in
85+
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
86+
let (metadata, elements) = try await rows.collectWithMetadata()
87+
var counter = 0
88+
for row in elements {
89+
let element = try row.decode(Int.self)
90+
XCTAssertEqual(element, counter + 1)
91+
counter += 1
92+
}
93+
94+
XCTAssertEqual(metadata.command, "SELECT")
95+
XCTAssertEqual(metadata.oid, nil)
96+
XCTAssertEqual(metadata.rows, 10000)
97+
98+
XCTAssertEqual(counter, end)
99+
}
100+
}
101+
49102
func testSelectActiveConnection() async throws {
50103
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
51104
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -207,7 +260,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
207260

208261
try await withTestConnection(on: eventLoop) { connection in
209262
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
210-
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
263+
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
211264
// Then we will insert 3 * 17 rows
212265
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
213266
// If the test is successful, it means Postgres supports UInt16.max bindings
@@ -241,13 +294,9 @@ final class AsyncPostgresConnectionTests: XCTestCase {
241294
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
242295
binds: binds
243296
)
244-
try await connection.query(insertionQuery, logger: .psqlTest)
245-
246-
let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
247-
let countRows = try await connection.query(countQuery, logger: .psqlTest)
248-
var countIterator = countRows.makeAsyncIterator()
249-
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
250-
XCTAssertEqual(rowsCount, insertedRowsCount)
297+
let result = try await connection.query(insertionQuery, logger: .psqlTest)
298+
let metadata = try await result.collectWithMetadata().metadata
299+
XCTAssertEqual(metadata.rows, rowsCount)
251300

252301
let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
253302
try await connection.query(dropQuery, logger: .psqlTest)

0 commit comments

Comments
 (0)