Skip to content

Commit 0236666

Browse files
committed
Handle error cases more exhaustively
1 parent 2e878d4 commit 0236666

File tree

9 files changed

+434
-58
lines changed

9 files changed

+434
-58
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,17 @@ extension PostgresConnection {
698698

699699
/// A handle to send
700700
public struct PostgresCopyFromWriter: Sendable {
701+
/// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed.
702+
///
703+
/// The `PostgresCopyFromWriter` should cancel the data transfer.
704+
public struct CopyCancellationError: Error {
705+
/// The error that the backend sent us which cancelled the data transfer.
706+
///
707+
/// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
708+
/// new data is written by `write`.
709+
let underlyingError: PSQLError
710+
}
711+
701712
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
702713
private let context: NIOLoopBound<ChannelHandlerContext>
703714
private let eventLoop: any EventLoop
@@ -709,16 +720,30 @@ public struct PostgresCopyFromWriter: Sendable {
709720
}
710721

711722
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
723+
///
724+
/// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
725+
/// a `CopyFailedError`.
712726
public func write(_ byteBuffer: ByteBuffer) async throws {
713-
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
727+
// First, wait that we have a writable buffer. This also throws a `CopyFailedError` in case the backend sent an
728+
// error during the data transfer and thus cannot process any more data.
729+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
714730
if eventLoop.inEventLoop {
715-
self.channelHandler.value.copyData(byteBuffer, context: self.context.value, readyForMoreWriteContinuation: continuation)
731+
self.channelHandler.value.waitForWritableBuffer(context: self.context.value, continuation)
716732
} else {
717733
eventLoop.execute {
718-
self.channelHandler.value.copyData(byteBuffer, context: self.context.value, readyForMoreWriteContinuation: continuation)
734+
self.channelHandler.value.waitForWritableBuffer(context: self.context.value, continuation)
719735
}
720736
}
721737
}
738+
739+
// Run the actual data transfer
740+
if eventLoop.inEventLoop {
741+
self.channelHandler.value.copyData(byteBuffer, context: self.context.value)
742+
} else {
743+
eventLoop.execute {
744+
self.channelHandler.value.copyData(byteBuffer, context: self.context.value)
745+
}
746+
}
722747
}
723748

724749
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
@@ -748,6 +773,17 @@ public struct PostgresCopyFromWriter: Sendable {
748773
}
749774
}
750775
}
776+
777+
/// Send a `Sync` message to the backend.
778+
func sync() {
779+
if eventLoop.inEventLoop {
780+
self.channelHandler.value.sendSync(context: self.context.value)
781+
} else {
782+
eventLoop.execute {
783+
self.channelHandler.value.sendSync(context: self.context.value)
784+
}
785+
}
786+
}
751787
}
752788

753789
extension PostgresConnection {
@@ -777,8 +813,15 @@ extension PostgresConnection {
777813

778814
do {
779815
try await writeData(writer)
816+
} catch let error as PostgresCopyFromWriter.CopyCancellationError {
817+
// If the copy was cancelled because the backend sent us an error, we need to send a `Sync` message to put
818+
// the backend out of the copy mode.
819+
writer.sync()
820+
throw error.underlyingError
780821
} catch {
781-
try await writer.failed(error: error)
822+
// Throw the error from the `writeData` closure instead of the one that Postgres gives us upon receiving the
823+
// `CopyFail` message.
824+
try? await writer.failed(error: error)
782825
throw error
783826
}
784827
try await writer.done()

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ struct ConnectionStateMachine {
8888
case sendParseDescribeBindExecuteSync(PostgresQuery)
8989
case sendBindExecuteSync(PSQLExecuteStatement)
9090
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
91-
/// Fail a query's execution by throwing an error on the given continuation.
92-
case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?)
91+
/// Fail a query's execution by throwing an error on the given continuation. When `sync` is `true`, send a
92+
/// `sync` message to the backend.
93+
case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?, sync: Bool)
9394
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
9495
case succeedQueryContinuation(CheckedContinuation<Void, any Error>)
9596

@@ -124,8 +125,24 @@ struct ConnectionStateMachine {
124125
}
125126

126127
enum ChannelWritabilityChangedAction {
128+
/// No action needs to be taken based on the writability change.
127129
case none
128-
case resumeContinuation(CheckedContinuation<Void, Never>)
130+
131+
/// Resume the given continuation successfully.
132+
case resumeContinuation(CheckedContinuation<Void, any Error>)
133+
}
134+
135+
enum WaitForWritableBufferAction {
136+
/// The channel has backpressure and cannot handle any data right now. We should flush the channel to help
137+
/// relieve backpressure. Once the channel is writable again, this will be communicated via
138+
/// `channelWritabilityChanged`
139+
case waitForBackpressureRelieve
140+
141+
/// Resume the given continuation successfully.
142+
case resumeContinuation(CheckedContinuation<Void, any Error>)
143+
144+
/// Fail the continuation with the given error.
145+
case failContinuation(CheckedContinuation<Void, any Error>, error: any Error)
129146
}
130147

131148
private var state: State
@@ -609,7 +626,7 @@ struct ConnectionStateMachine {
609626
case .executeStatement(_, let promise), .unnamed(_, let promise):
610627
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
611628
case .copyFrom(_, let triggerCopy):
612-
return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil)
629+
return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil, sync: false)
613630
case .prepareStatement(_, _, _, let promise):
614631
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
615632
}
@@ -798,16 +815,20 @@ struct ConnectionStateMachine {
798815
return self.modify(with: action)
799816
}
800817

801-
/// Assuming that the channel to the backend is not writable, wait for the write buffer to become writable again and
802-
/// then resume `continuation`.
803-
mutating func waitForWritableBuffer(continuation: CheckedContinuation<Void, Never>) {
818+
/// Wait fo `channel` to be writable and be able to handle more `CopyData` messages. Resume the given continuation
819+
/// when the channel is able handle more data.
820+
///
821+
/// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled
822+
/// the data transfer to indicate that the frontend should not send any more data.
823+
mutating func waitForWritableBuffer(channel: any Channel, continuation: CheckedContinuation<Void, any Error>) -> WaitForWritableBufferAction {
804824
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
805825
preconditionFailure("Copy mode is only supported for extended queries")
806826
}
807827

808828
self.state = .modifying // avoid CoW
809-
queryState.waitForWritableBuffer(continuation: continuation)
829+
let action = queryState.waitForWritableBuffer(channel: channel, continuation: continuation)
810830
self.state = .extendedQuery(queryState, connectionContext)
831+
return action
811832
}
812833

813834
/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
@@ -955,8 +976,8 @@ struct ConnectionStateMachine {
955976
case .failQuery(let promise, with: let error):
956977
return .failQuery(promise, with: error, cleanupContext: cleanupContext)
957978

958-
case .failQueryContinuation(let continuation, with: let error):
959-
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)
979+
case .failQueryContinuation(let continuation, with: let error, let sync):
980+
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync)
960981

961982
case .forwardStreamError(let error, let read):
962983
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
@@ -1127,9 +1148,9 @@ extension ConnectionStateMachine {
11271148
case .failQuery(let requestContext, with: let error):
11281149
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
11291150
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
1130-
case .failQueryContinuation(let continuation, with: let error):
1151+
case .failQueryContinuation(let continuation, with: let error, let sync):
11311152
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
1132-
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)
1153+
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext, sync: sync)
11331154
case .succeedQuery(let requestContext, with: let result):
11341155
return .succeedQuery(requestContext, with: result)
11351156
case .succeedQueryContinuation(let continuation):

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct ExtendedQueryStateMachine {
88

99
/// The write channel has backpressure. Once that is relieved, we should resume the given continuation to allow more
1010
/// data to be sent by the client.
11-
case pendingBackpressureRelieve(CheckedContinuation<Void, Never>)
11+
case pendingBackpressureRelieve(CheckedContinuation<Void, any Error>)
1212
}
1313

1414
private enum State {
@@ -50,8 +50,9 @@ struct ExtendedQueryStateMachine {
5050

5151
// --- general actions
5252
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
53-
/// Fail a query's execution by throwing an error on the given continuation.
54-
case failQueryContinuation(AnyErrorContinuation, with: PSQLError)
53+
/// Fail a query's execution by throwing an error on the given continuation. If `sync` is `true`, send a `sync`
54+
/// message to the backend to put it out of the copy mode.
55+
case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool)
5556
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
5657
case succeedQueryContinuation(CheckedContinuation<Void, any Error>)
5758

@@ -142,14 +143,14 @@ struct ExtendedQueryStateMachine {
142143
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
143144
return .failQuery(eventLoopPromise, with: .queryCancelled)
144145
case .copyFrom(_, let triggerCopy):
145-
return .failQueryContinuation(triggerCopy, with: .queryCancelled)
146+
return .failQueryContinuation(triggerCopy, with: .queryCancelled, sync: false)
146147
case .prepareStatement(_, _, _, let eventLoopPromise):
147148
return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled)
148149
}
149150
case .copyingData:
150151
return .sendCopyFailed(message: "Query cancelled")
151152
case .copyingFinished(let continuation):
152-
return .failQueryContinuation(continuation, with: .queryCancelled)
153+
return .failQueryContinuation(continuation, with: .queryCancelled, sync: true)
153154

154155
case .streaming(let columns, var streamStateMachine):
155156
precondition(!self.isCancelled)
@@ -392,24 +393,40 @@ struct ExtendedQueryStateMachine {
392393
}
393394
}
394395

395-
/// Assuming that the channel to the backend is not writable, wait for the write buffer to become writable again and
396-
/// then resume `continuation`.
397-
mutating func waitForWritableBuffer(continuation: CheckedContinuation<Void, Never>) {
396+
/// Wait fo `channel` to be writable and be able to handle more `CopyData` messages. Resume the given continuation
397+
/// when the channel is able handle more data.
398+
///
399+
/// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled
400+
/// the data transfer to indicate that the frontend should not send any more data.
401+
mutating func waitForWritableBuffer(
402+
channel: any Channel,
403+
continuation: CheckedContinuation<Void, any Error>
404+
) -> ConnectionStateMachine.WaitForWritableBufferAction {
405+
if case .error(let error) = self.state {
406+
return .failContinuation(continuation, error: PostgresCopyFromWriter.CopyCancellationError(underlyingError: error))
407+
}
398408
guard case .copyingData(let copyingSubstate) = self.state else {
399409
preconditionFailure("Must be in copy mode to copy data")
400410
}
401411
guard case .readyToSend = copyingSubstate else {
402412
preconditionFailure("Not ready to send data")
403413
}
414+
if channel.isWritable {
415+
return .resumeContinuation(continuation)
416+
}
404417
return avoidingStateMachineCoW { state in
405418
// Even if the buffer isn't writable, we write the current chunk of data to it. We just don't resume
406419
// the continuation. This will prevent more writes from happening to build up more write backpressure.
407420
state = .copyingData(.pendingBackpressureRelieve(continuation))
421+
return .waitForBackpressureRelieve
408422
}
409423
}
410424

411425
/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
412426
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> Action {
427+
if case .error(let error) = self.state {
428+
return .failQueryContinuation(continuation, with: error, sync: true)
429+
}
413430
guard case .copyingData = self.state else {
414431
preconditionFailure("Must be in copy mode to send CopyDone")
415432
}
@@ -421,6 +438,9 @@ struct ExtendedQueryStateMachine {
421438

422439
/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
423440
mutating func sendCopyFailed(message: String, continuation: CheckedContinuation<Void, any Error>) -> Action {
441+
if case .error(let error) = self.state {
442+
return .failQueryContinuation(continuation, with: error, sync: true)
443+
}
424444
guard case .copyingData = self.state else {
425445
preconditionFailure("Must be in copy mode to send CopyFail")
426446
}
@@ -627,7 +647,7 @@ struct ExtendedQueryStateMachine {
627647
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
628648
return .failQuery(eventLoopPromise, with: error)
629649
case .copyFrom(_, let triggerCopy):
630-
return .failQueryContinuation(triggerCopy, with: error)
650+
return .failQueryContinuation(triggerCopy, with: error, sync: false)
631651
case .prepareStatement(_, _, _, let eventLoopPromise):
632652
return .failPreparedStatementCreation(eventLoopPromise, with: error)
633653
}
@@ -637,7 +657,7 @@ struct ExtendedQueryStateMachine {
637657
return .evaluateErrorAtConnectionLevel(error)
638658
case .copyingFinished(let continuation):
639659
self.state = .error(error)
640-
return .failQueryContinuation(continuation, with: error)
660+
return .failQueryContinuation(continuation, with: error, sync: true)
641661
case .drain:
642662
self.state = .error(error)
643663
return .evaluateErrorAtConnectionLevel(error)

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,26 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
171171
self.run(action, with: context)
172172
}
173173

174-
/// Send a `CopyData` message to the backend using the given data.
174+
/// Wait for the channel to be writable and be able to handle more `CopyData` messages. Resume the given
175+
/// continuation when the channel is able handle more data.
175176
///
176-
/// `readyForMoreWriteContinuation` is resumed when the channel is able to handle more data to be written to it.
177-
func copyData(_ data: ByteBuffer, context: ChannelHandlerContext, readyForMoreWriteContinuation: CheckedContinuation<Void, Never>) {
178-
self.encoder.copyData(data: data)
179-
if context.channel.isWritable {
180-
readyForMoreWriteContinuation.resume()
181-
} else {
182-
self.state.waitForWritableBuffer(continuation: readyForMoreWriteContinuation)
183-
context.flush()
177+
/// This fails the continuation with a `PostgresCopyFromWriter.CopyCancellationError` when the server has cancelled
178+
/// the data transfer to indicate that the frontend should not send any more data.
179+
func waitForWritableBuffer(context: ChannelHandlerContext, _ continuation: CheckedContinuation<Void, any Error>) {
180+
let action = self.state.waitForWritableBuffer(channel: context.channel, continuation: continuation)
181+
switch action {
182+
case .waitForBackpressureRelieve:
183+
context.channel.flush()
184+
case .resumeContinuation(let continuation):
185+
continuation.resume()
186+
case .failContinuation(_, error: let error):
187+
continuation.resume(throwing: error)
184188
}
189+
}
190+
191+
/// Send a `CopyData` message to the backend using the given data.
192+
func copyData(_ data: ByteBuffer, context: ChannelHandlerContext) {
193+
self.encoder.copyData(data: data)
185194
context.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
186195
}
187196

@@ -197,6 +206,12 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
197206
self.run(action, with: context)
198207
}
199208

209+
/// Send a `Sync` message to the backend.
210+
func sendSync(context: ChannelHandlerContext) {
211+
self.encoder.sync()
212+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
213+
}
214+
200215
func channelReadComplete(context: ChannelHandlerContext) {
201216
let action = self.state.channelReadComplete()
202217
self.run(action, with: context)
@@ -398,11 +413,15 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
398413
if let cleanupContext = cleanupContext {
399414
self.closeConnectionAndCleanup(cleanupContext, context: context)
400415
}
401-
case .failQueryContinuation(let continuation, with: let error, let cleanupContext):
402-
continuation.resume(throwing: error)
416+
case .failQueryContinuation(let continuation, with: let error, let cleanupContext, let sync):
417+
if sync {
418+
self.encoder.sync()
419+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
420+
}
403421
if let cleanupContext = cleanupContext {
404422
self.closeConnectionAndCleanup(cleanupContext, context: context)
405423
}
424+
continuation.resume(throwing: error)
406425
case .triggerCopyData(let triggerCopy):
407426
let writer = PostgresCopyFromWriter(handler: self, context: context, eventLoop: eventLoop)
408427
triggerCopy.resume(returning: writer)

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,10 @@ final class IntegrationTests: XCTestCase {
401401
}
402402
}, logger: .psqlTest)
403403
let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) }
404-
XCTAssertEqual(rows.count, 2)
404+
guard rows.count == 2 else {
405+
XCTFail("Expected 2 columns, received \(rows.count)")
406+
return
407+
}
405408
XCTAssertEqual(rows[0].0, 1)
406409
XCTAssertEqual(rows[0].1, "Alice")
407410
XCTAssertEqual(rows[1].0, 42)

0 commit comments

Comments
 (0)