Skip to content

Commit 44390f0

Browse files
authored
Merge branch 'main' into mmbm-fix-promise-leak
2 parents 3c287b9 + 9f84290 commit 44390f0

File tree

8 files changed

+114
-48
lines changed

8 files changed

+114
-48
lines changed

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine {
1010
case parameterDescriptionReceived(ExtendedQueryContext)
1111
case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column])
1212
case noDataMessageReceived(ExtendedQueryContext)
13-
13+
case emptyQueryResponseReceived
14+
1415
/// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is
1516
/// used after receiving a `bindComplete` message
1617
case bindCompleteReceived(ExtendedQueryContext)
@@ -123,7 +124,7 @@ struct ExtendedQueryStateMachine {
123124
return .forwardStreamError(.queryCancelled, read: true)
124125
}
125126

126-
case .commandComplete, .error, .drain:
127+
case .commandComplete, .emptyQueryResponseReceived, .error, .drain:
127128
// the stream has already finished.
128129
return .wait
129130

@@ -230,6 +231,7 @@ struct ExtendedQueryStateMachine {
230231
.messagesSent,
231232
.parseCompleteReceived,
232233
.parameterDescriptionReceived,
234+
.emptyQueryResponseReceived,
233235
.bindCompleteReceived,
234236
.streaming,
235237
.drain,
@@ -269,6 +271,7 @@ struct ExtendedQueryStateMachine {
269271
.parseCompleteReceived,
270272
.parameterDescriptionReceived,
271273
.noDataMessageReceived,
274+
.emptyQueryResponseReceived,
272275
.rowDescriptionReceived,
273276
.bindCompleteReceived,
274277
.commandComplete,
@@ -286,7 +289,7 @@ struct ExtendedQueryStateMachine {
286289
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
287290
return self.avoidingStateMachineCoW { state -> Action in
288291
state = .commandComplete(commandTag: commandTag)
289-
let result = QueryResult(value: .noRows(commandTag), logger: context.logger)
292+
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
290293
return .succeedQuery(eventLoopPromise, with: result)
291294
}
292295

@@ -310,6 +313,7 @@ struct ExtendedQueryStateMachine {
310313
.parseCompleteReceived,
311314
.parameterDescriptionReceived,
312315
.noDataMessageReceived,
316+
.emptyQueryResponseReceived,
313317
.rowDescriptionReceived,
314318
.commandComplete,
315319
.error:
@@ -320,7 +324,22 @@ struct ExtendedQueryStateMachine {
320324
}
321325

322326
mutating func emptyQueryResponseReceived() -> Action {
323-
preconditionFailure("Unimplemented")
327+
guard case .bindCompleteReceived(let queryContext) = self.state else {
328+
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
329+
}
330+
331+
switch queryContext.query {
332+
case .unnamed(_, let eventLoopPromise),
333+
.executeStatement(_, let eventLoopPromise):
334+
return self.avoidingStateMachineCoW { state -> Action in
335+
state = .emptyQueryResponseReceived
336+
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
337+
return .succeedQuery(eventLoopPromise, with: result)
338+
}
339+
340+
case .prepareStatement(_, _, _, _):
341+
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
342+
}
324343
}
325344

326345
mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action {
@@ -337,7 +356,7 @@ struct ExtendedQueryStateMachine {
337356
return self.setAndFireError(error)
338357
case .streaming, .drain:
339358
return self.setAndFireError(error)
340-
case .commandComplete:
359+
case .commandComplete, .emptyQueryResponseReceived:
341360
return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage)))
342361
case .error:
343362
preconditionFailure("""
@@ -383,6 +402,7 @@ struct ExtendedQueryStateMachine {
383402
.parseCompleteReceived,
384403
.parameterDescriptionReceived,
385404
.noDataMessageReceived,
405+
.emptyQueryResponseReceived,
386406
.rowDescriptionReceived,
387407
.bindCompleteReceived:
388408
preconditionFailure("Requested to consume next row without anything going on.")
@@ -406,6 +426,7 @@ struct ExtendedQueryStateMachine {
406426
.parseCompleteReceived,
407427
.parameterDescriptionReceived,
408428
.noDataMessageReceived,
429+
.emptyQueryResponseReceived,
409430
.rowDescriptionReceived,
410431
.bindCompleteReceived:
411432
return .wait
@@ -450,6 +471,7 @@ struct ExtendedQueryStateMachine {
450471
}
451472
case .initialized,
452473
.commandComplete,
474+
.emptyQueryResponseReceived,
453475
.drain,
454476
.error:
455477
// we already have the complete stream received, now we are waiting for a
@@ -496,7 +518,7 @@ struct ExtendedQueryStateMachine {
496518
return .forwardStreamError(error, read: true)
497519
}
498520

499-
case .commandComplete, .error:
521+
case .commandComplete, .emptyQueryResponseReceived, .error:
500522
preconditionFailure("""
501523
This state must not be reached. If the query `.isComplete`, the
502524
ConnectionStateMachine must not send any further events to the substate machine.
@@ -508,7 +530,7 @@ struct ExtendedQueryStateMachine {
508530

509531
var isComplete: Bool {
510532
switch self.state {
511-
case .commandComplete, .error:
533+
case .commandComplete, .emptyQueryResponseReceived, .error:
512534
return true
513535

514536
case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _):

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Logging
33

44
struct QueryResult {
55
enum Value: Equatable {
6-
case noRows(String)
6+
case noRows(PSQLRowStream.StatementSummary)
77
case rowDescription([RowDescription.Column])
88
}
99

@@ -16,25 +16,30 @@ struct QueryResult {
1616
final class PSQLRowStream: @unchecked Sendable {
1717
private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer<DataRow, Error, AdaptiveRowBuffer, PSQLRowStream>.Source
1818

19+
enum StatementSummary: Equatable {
20+
case tag(String)
21+
case emptyResponse
22+
}
23+
1924
enum Source {
2025
case stream([RowDescription.Column], PSQLRowsDataSource)
21-
case noRows(Result<String, Error>)
26+
case noRows(Result<StatementSummary, Error>)
2227
}
2328

2429
let eventLoop: EventLoop
2530
let logger: Logger
26-
31+
2732
private enum BufferState {
2833
case streaming(buffer: CircularBuffer<DataRow>, dataSource: PSQLRowsDataSource)
29-
case finished(buffer: CircularBuffer<DataRow>, commandTag: String)
34+
case finished(buffer: CircularBuffer<DataRow>, summary: StatementSummary)
3035
case failure(Error)
3136
}
32-
37+
3338
private enum DownstreamState {
3439
case waitingForConsumer(BufferState)
3540
case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
3641
case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource)
37-
case consumed(Result<String, Error>)
42+
case consumed(Result<StatementSummary, Error>)
3843
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ())
3944
}
4045

@@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable {
5257
case .stream(let rowDescription, let dataSource):
5358
self.rowDescription = rowDescription
5459
bufferState = .streaming(buffer: .init(), dataSource: dataSource)
55-
case .noRows(.success(let commandTag)):
60+
case .noRows(.success(let summary)):
5661
self.rowDescription = []
57-
bufferState = .finished(buffer: .init(), commandTag: commandTag)
62+
bufferState = .finished(buffer: .init(), summary: summary)
5863
case .noRows(.failure(let error)):
5964
self.rowDescription = []
6065
bufferState = .failure(error)
@@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable {
98103
self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish)
99104
self.executeActionBasedOnYieldResult(yieldResult, source: dataSource)
100105

101-
case .finished(let buffer, let commandTag):
106+
case .finished(let buffer, let summary):
102107
_ = source.yield(contentsOf: buffer)
103108
source.finish()
104109
onFinish()
105-
self.downstreamState = .consumed(.success(commandTag))
106-
110+
self.downstreamState = .consumed(.success(summary))
111+
107112
case .failure(let error):
108113
source.finish(error)
109114
self.downstreamState = .consumed(.failure(error))
@@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable {
190195
dataSource.request(for: self)
191196
return promise.futureResult
192197

193-
case .finished(let buffer, let commandTag):
198+
case .finished(let buffer, let summary):
194199
let rows = buffer.map {
195200
PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription)
196201
}
197202

198-
self.downstreamState = .consumed(.success(commandTag))
203+
self.downstreamState = .consumed(.success(summary))
199204
return self.eventLoop.makeSucceededFuture(rows)
200205

201206
case .failure(let error):
@@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable {
247252
}
248253

249254
return promise.futureResult
250-
251-
case .finished(let buffer, let commandTag):
255+
256+
case .finished(let buffer, let summary):
252257
do {
253258
for data in buffer {
254259
let row = PostgresRow(
@@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable {
259264
try onRow(row)
260265
}
261266

262-
self.downstreamState = .consumed(.success(commandTag))
267+
self.downstreamState = .consumed(.success(summary))
263268
return self.eventLoop.makeSucceededVoidFuture()
264269
} catch {
265270
self.downstreamState = .consumed(.failure(error))
@@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable {
292297

293298
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
294299
preconditionFailure("How can new rows be received, if an end was already signalled?")
295-
300+
296301
case .iteratingRows(let onRow, let promise, let dataSource):
297302
do {
298303
for data in newRows {
@@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable {
347352
private func receiveEnd(_ commandTag: String) {
348353
switch self.downstreamState {
349354
case .waitingForConsumer(.streaming(buffer: let buffer, _)):
350-
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag))
351-
352-
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
355+
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag)))
356+
357+
case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)):
353358
preconditionFailure("How can we get another end, if an end was already signalled?")
354359

355360
case .iteratingRows(_, let promise, _):
356-
self.downstreamState = .consumed(.success(commandTag))
361+
self.downstreamState = .consumed(.success(.tag(commandTag)))
357362
promise.succeed(())
358363

359364
case .waitingForAll(let rows, let promise, _):
360-
self.downstreamState = .consumed(.success(commandTag))
365+
self.downstreamState = .consumed(.success(.tag(commandTag)))
361366
promise.succeed(rows)
362367

363368
case .asyncSequence(let source, _, let onFinish):
364-
self.downstreamState = .consumed(.success(commandTag))
369+
self.downstreamState = .consumed(.success(.tag(commandTag)))
365370
source.finish()
366371
onFinish()
367372

368-
case .consumed:
373+
case .consumed(.success(.tag)), .consumed(.failure):
369374
break
370375
}
371376
}
@@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable {
375380
case .waitingForConsumer(.streaming):
376381
self.downstreamState = .waitingForConsumer(.failure(error))
377382

378-
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
383+
case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)):
379384
preconditionFailure("How can we get another end, if an end was already signalled?")
380385

381386
case .iteratingRows(_, let promise, _):
@@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable {
391396
consumer.finish(error)
392397
onFinish()
393398

394-
case .consumed:
399+
case .consumed(.success(.tag)), .consumed(.failure):
395400
break
396401
}
397402
}
@@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable {
413418
}
414419

415420
var commandTag: String {
416-
guard case .consumed(.success(let commandTag)) = self.downstreamState else {
421+
guard case .consumed(.success(let consumed)) = self.downstreamState else {
417422
preconditionFailure("commandTag may only be called if all rows have been consumed")
418423
}
419-
return commandTag
424+
switch consumed {
425+
case .tag(let tag):
426+
return tag
427+
case .emptyResponse:
428+
return ""
429+
}
420430
}
421431
}
422432

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
574574
)
575575
self.rowStream = rows
576576

577-
case .noRows(let commandTag):
577+
case .noRows(let summary):
578578
rows = PSQLRowStream(
579-
source: .noRows(.success(commandTag)),
579+
source: .noRows(.success(summary)),
580580
eventLoop: context.channel.eventLoop,
581581
logger: result.logger
582582
)

Sources/PostgresNIO/PostgresDatabase+Query.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable {
7373

7474
init?(string: String) {
7575
let parts = string.split(separator: " ")
76-
guard parts.count >= 1 else {
77-
return nil
78-
}
79-
switch parts[0] {
76+
switch parts.first {
8077
case "INSERT":
8178
// INSERT oid rows
8279
guard parts.count == 3 else {

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase {
123123
XCTAssertEqual(foo, "hello")
124124
}
125125

126+
func testQueryNothing() throws {
127+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
128+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
129+
let eventLoop = eventLoopGroup.next()
130+
131+
var conn: PostgresConnection?
132+
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
133+
defer { XCTAssertNoThrow(try conn?.close().wait()) }
134+
135+
var _result: PostgresQueryResult?
136+
XCTAssertNoThrow(_result = try conn?.query("""
137+
-- Some comments
138+
""", logger: .psqlTest).wait())
139+
140+
let result = try XCTUnwrap(_result)
141+
XCTAssertEqual(result.rows, [])
142+
XCTAssertEqual(result.metadata.command, "")
143+
}
144+
126145
func testDecodeIntegers() {
127146
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
128147
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }

Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
2020
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
2121
XCTAssertEqual(state.noDataReceived(), .wait)
2222
XCTAssertEqual(state.bindCompleteReceived(), .wait)
23-
XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger)))
23+
XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger)))
2424
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
2525
}
2626

@@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase {
7777
XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2"))
7878
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
7979
}
80-
80+
81+
func testExtendedQueryWithNoQuery() {
82+
var state = ConnectionStateMachine.readyForQuery()
83+
84+
let logger = Logger.psqlTest
85+
let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self)
86+
promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all.
87+
let query: PostgresQuery = "-- some comments"
88+
let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise)
89+
90+
XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
91+
XCTAssertEqual(state.parseCompleteReceived(), .wait)
92+
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
93+
XCTAssertEqual(state.noDataReceived(), .wait)
94+
XCTAssertEqual(state.bindCompleteReceived(), .wait)
95+
XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger)))
96+
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
97+
}
98+
8199
func testReceiveTotallyUnexpectedMessageInQuery() {
82100
var state = ConnectionStateMachine.readyForQuery()
83101

0 commit comments

Comments
 (0)