Skip to content

Commit 87b9868

Browse files
committed
WIP: Implement COPY … FROM STDIN
This is still WIP. To do items include: - [ ] Test the various error cases, I have mostly focused on the success case so far - [ ] Test the backpressure support - [ ] Change the public API to accept the table + columns to copy into as well as options so that we can build the `COPY` query instead of letting the user write it - [ ] Add an API that allows binary transfer of data
1 parent c17db2f commit 87b9868

File tree

13 files changed

+544
-27
lines changed

13 files changed

+544
-27
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,100 @@ extension PostgresConnection {
694694
}
695695
}
696696

697+
// MARK: Copy from
698+
699+
fileprivate extension EventLoop {
700+
/// If we are on the given event loop, execute `task` immediately. Otherwise schedule it for execution.
701+
func executeImmediatelyOrSchedule(_ task: @Sendable @escaping () -> Void) {
702+
if inEventLoop {
703+
return task()
704+
}
705+
return execute(task)
706+
}
707+
}
708+
709+
/// A handle to send
710+
public struct PostgresCopyFromWriter: Sendable {
711+
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
712+
private let context: NIOLoopBound<ChannelHandlerContext>
713+
private let eventLoop: any EventLoop
714+
715+
struct NotWritableError: Error, CustomStringConvertible {
716+
var description = "No data must be written to `PostgresCopyFromWriter` after it has sent a CopyDone or CopyFail message, ie. after the closure producing the copy data has finished"
717+
}
718+
719+
init(handler: PostgresChannelHandler, context: ChannelHandlerContext, eventLoop: any EventLoop) {
720+
self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
721+
self.context = NIOLoopBound(context, eventLoop: eventLoop)
722+
self.eventLoop = eventLoop
723+
}
724+
725+
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
726+
public func write(_ byteBuffer: ByteBuffer) async throws {
727+
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
728+
eventLoop.executeImmediatelyOrSchedule {
729+
self.channelHandler.value.copyData(byteBuffer, context: self.context.value, readyForMoreWriteContinuation: continuation)
730+
}
731+
}
732+
}
733+
734+
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
735+
/// the backend.
736+
func done() async throws {
737+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
738+
eventLoop.executeImmediatelyOrSchedule {
739+
self.channelHandler.value.sendCopyDone(continuation: continuation, context: self.context.value)
740+
}
741+
}
742+
}
743+
744+
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
745+
/// the backend.
746+
func failed(error: any Error) async throws {
747+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
748+
eventLoop.executeImmediatelyOrSchedule {
749+
self.channelHandler.value.sendCopyFailed(message: "\(error)", continuation: continuation, context: self.context.value)
750+
}
751+
}
752+
}
753+
}
754+
755+
extension PostgresConnection {
756+
// TODO: Instead of an arbitrary query, make this a structured data structure.
757+
// TODO: Write doc comment
758+
public func copyFrom(
759+
_ query: PostgresQuery,
760+
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void,
761+
logger: Logger,
762+
file: String = #fileID,
763+
line: Int = #line
764+
) async throws {
765+
var logger = logger
766+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
767+
guard query.binds.count <= Int(UInt16.max) else {
768+
throw PSQLError(code: .tooManyParameters, query: query)
769+
}
770+
771+
let writer = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<PostgresCopyFromWriter, any Error>) in
772+
let context = ExtendedQueryContext(
773+
copyFromQuery: query,
774+
triggerCopy: continuation,
775+
logger: logger
776+
)
777+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
778+
}
779+
780+
do {
781+
try await writeData(writer)
782+
} catch {
783+
try await writer.failed(error: error)
784+
throw error
785+
}
786+
try await writer.done()
787+
}
788+
789+
}
790+
697791
// MARK: PostgresDatabase conformance
698792

699793
extension PostgresConnection: PostgresDatabase {

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,23 @@ struct ConnectionStateMachine {
8888
case sendParseDescribeBindExecuteSync(PostgresQuery)
8989
case sendBindExecuteSync(PSQLExecuteStatement)
9090
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
91+
/// Fail a query's execution by throwing an error on the given continuation.
92+
case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?)
9193
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
94+
case succeedQueryContinuation(CheckedContinuation<Void, any Error>)
95+
96+
/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
97+
///
98+
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
99+
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
100+
/// `PostgresChannelHandler.copyDone` or ``PostgresChannelHandler.copyFailed``.
101+
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)
102+
103+
/// Send a `CopyDone` message to the backend, followed by a `Sync`.
104+
case sendCopyDone
105+
106+
/// Send a `CopyFail` message to the backend with the given error message.
107+
case sendCopyFailed(message: String)
92108

93109
// --- streaming actions
94110
// actions if query has requested next row but we are waiting for backend
@@ -587,6 +603,8 @@ struct ConnectionStateMachine {
587603
switch queryContext.query {
588604
case .executeStatement(_, let promise), .unnamed(_, let promise):
589605
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
606+
case .copyFrom(_, let triggerCopy):
607+
return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil)
590608
case .prepareStatement(_, _, _, let promise):
591609
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
592610
}
@@ -660,6 +678,15 @@ struct ConnectionStateMachine {
660678
preconditionFailure("Invalid state: \(self.state)")
661679
}
662680
}
681+
682+
mutating func channelWritabilityChanged(isWritable: Bool) {
683+
guard case .extendedQuery(var queryState, let connectionContext) = state else {
684+
return
685+
}
686+
self.state = .modifying // avoid CoW
687+
queryState.channelWritabilityChanged(isWritable: isWritable)
688+
self.state = .extendedQuery(queryState, connectionContext)
689+
}
663690

664691
// MARK: - Running Queries -
665692

@@ -751,6 +778,55 @@ struct ConnectionStateMachine {
751778
self.state = .extendedQuery(queryState, connectionContext)
752779
return self.modify(with: action)
753780
}
781+
782+
mutating func copyInResponseReceived(
783+
_ copyInResponse: PostgresBackendMessage.CopyInResponseMessage
784+
) -> ConnectionAction {
785+
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
786+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))
787+
}
788+
789+
self.state = .modifying // avoid CoW
790+
let action = queryState.copyInResponseReceived(copyInResponse)
791+
self.state = .extendedQuery(queryState, connectionContext)
792+
return self.modify(with: action)
793+
}
794+
795+
/// Assuming that the channel to the backend is not writable, wait for the write buffer to become writable again and
796+
/// then resume `continuation`.
797+
mutating func waitForWritableBuffer(continuation: CheckedContinuation<Void, Never>) {
798+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
799+
preconditionFailure("Copy mode is only supported for extended queries")
800+
}
801+
802+
self.state = .modifying // avoid CoW
803+
queryState.waitForWritableBuffer(continuation: continuation)
804+
self.state = .extendedQuery(queryState, connectionContext)
805+
}
806+
807+
/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
808+
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
809+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
810+
preconditionFailure("Copy mode is only supported for extended queries")
811+
}
812+
813+
self.state = .modifying // avoid CoW
814+
let action = queryState.sendCopyDone(continuation: continuation)
815+
self.state = .extendedQuery(queryState, connectionContext)
816+
return self.modify(with: action)
817+
}
818+
819+
/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
820+
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
821+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
822+
preconditionFailure("Copy mode is only supported for extended queries")
823+
}
824+
825+
self.state = .modifying // avoid CoW
826+
let action = queryState.sendCopyFailed(message: message, continuation: continuation)
827+
self.state = .extendedQuery(queryState, connectionContext)
828+
return self.modify(with: action)
829+
}
754830

755831
mutating func emptyQueryResponseReceived() -> ConnectionAction {
756832
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
@@ -860,14 +936,21 @@ struct ConnectionStateMachine {
860936
.forwardRows,
861937
.forwardStreamComplete,
862938
.wait,
863-
.read:
939+
.read,
940+
.triggerCopyData,
941+
.sendCopyDone,
942+
.sendCopyFailed,
943+
.succeedQueryContinuation:
864944
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")
865945

866946
case .evaluateErrorAtConnectionLevel:
867947
return .closeConnectionAndCleanup(cleanupContext)
868948

869-
case .failQuery(let queryContext, with: let error):
870-
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
949+
case .failQuery(let promise, with: let error):
950+
return .failQuery(promise, with: error, cleanupContext: cleanupContext)
951+
952+
case .failQueryContinuation(let continuation, with: let error):
953+
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)
871954

872955
case .forwardStreamError(let error, let read):
873956
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
@@ -1038,8 +1121,19 @@ extension ConnectionStateMachine {
10381121
case .failQuery(let requestContext, with: let error):
10391122
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
10401123
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
1124+
case .failQueryContinuation(let continuation, with: let error):
1125+
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
1126+
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)
10411127
case .succeedQuery(let requestContext, with: let result):
10421128
return .succeedQuery(requestContext, with: result)
1129+
case .succeedQueryContinuation(let continuation):
1130+
return .succeedQueryContinuation(continuation)
1131+
case .triggerCopyData(let triggerCopy):
1132+
return .triggerCopyData(triggerCopy)
1133+
case .sendCopyDone:
1134+
return .sendCopyDone
1135+
case .sendCopyFailed(message: let message):
1136+
return .sendCopyFailed(message: message)
10431137
case .forwardRows(let buffer):
10441138
return .forwardRows(buffer)
10451139
case .forwardStreamComplete(let buffer, let commandTag):

0 commit comments

Comments
 (0)