@@ -88,7 +88,23 @@ struct ConnectionStateMachine {
8888 case sendParseDescribeBindExecuteSync( PostgresQuery )
8989 case sendBindExecuteSync( PSQLExecuteStatement )
9090 case failQuery( EventLoopPromise < PSQLRowStream > , with: PSQLError , cleanupContext: CleanUpContext ? )
91+ /// Fail a query's execution by throwing an error on the given continuation.
92+ case failQueryContinuation( any AnyErrorContinuation , with: PSQLError , cleanupContext: CleanUpContext ? )
9193 case succeedQuery( EventLoopPromise < PSQLRowStream > , with: QueryResult )
94+ case succeedQueryContinuation( CheckedContinuation < Void , any Error > )
95+
96+ /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
97+ ///
98+ /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
99+ /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
100+ /// `PostgresChannelHandler.copyDone` or ``PostgresChannelHandler.copyFailed``.
101+ case triggerCopyData( CheckedContinuation < PostgresCopyFromWriter , any Error > )
102+
103+ /// Send a `CopyDone` message to the backend, followed by a `Sync`.
104+ case sendCopyDone
105+
106+ /// Send a `CopyFail` message to the backend with the given error message.
107+ case sendCopyFailed( message: String )
92108
93109 // --- streaming actions
94110 // actions if query has requested next row but we are waiting for backend
@@ -587,6 +603,8 @@ struct ConnectionStateMachine {
587603 switch queryContext. query {
588604 case . executeStatement( _, let promise) , . unnamed( _, let promise) :
589605 return . failQuery( promise, with: psqlErrror, cleanupContext: nil )
606+ case . copyFrom( _, let triggerCopy) :
607+ return . failQueryContinuation( triggerCopy, with: psqlErrror, cleanupContext: nil )
590608 case . prepareStatement( _, _, _, let promise) :
591609 return . failPreparedStatementCreation( promise, with: psqlErrror, cleanupContext: nil )
592610 }
@@ -660,6 +678,15 @@ struct ConnectionStateMachine {
660678 preconditionFailure ( " Invalid state: \( self . state) " )
661679 }
662680 }
681+
682+ mutating func channelWritabilityChanged( isWritable: Bool ) {
683+ guard case . extendedQuery( var queryState, let connectionContext) = state else {
684+ return
685+ }
686+ self . state = . modifying // avoid CoW
687+ queryState. channelWritabilityChanged ( isWritable: isWritable)
688+ self . state = . extendedQuery( queryState, connectionContext)
689+ }
663690
664691 // MARK: - Running Queries -
665692
@@ -751,6 +778,55 @@ struct ConnectionStateMachine {
751778 self . state = . extendedQuery( queryState, connectionContext)
752779 return self . modify ( with: action)
753780 }
781+
782+ mutating func copyInResponseReceived(
783+ _ copyInResponse: PostgresBackendMessage . CopyInResponseMessage
784+ ) -> ConnectionAction {
785+ guard case . extendedQuery( var queryState, let connectionContext) = self . state, !queryState. isComplete else {
786+ return self . closeConnectionAndCleanup ( . unexpectedBackendMessage( . emptyQueryResponse) )
787+ }
788+
789+ self . state = . modifying // avoid CoW
790+ let action = queryState. copyInResponseReceived ( copyInResponse)
791+ self . state = . extendedQuery( queryState, connectionContext)
792+ return self . modify ( with: action)
793+ }
794+
795+ /// Assuming that the channel to the backend is not writable, wait for the write buffer to become writable again and
796+ /// then resume `continuation`.
797+ mutating func waitForWritableBuffer( continuation: CheckedContinuation < Void , Never > ) {
798+ guard case . extendedQuery( var queryState, let connectionContext) = self . state else {
799+ preconditionFailure ( " Copy mode is only supported for extended queries " )
800+ }
801+
802+ self . state = . modifying // avoid CoW
803+ queryState. waitForWritableBuffer ( continuation: continuation)
804+ self . state = . extendedQuery( queryState, connectionContext)
805+ }
806+
807+ /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
808+ mutating func sendCopyDone( continuation: CheckedContinuation < Void , any Error > ) -> ConnectionAction {
809+ guard case . extendedQuery( var queryState, let connectionContext) = self . state else {
810+ preconditionFailure ( " Copy mode is only supported for extended queries " )
811+ }
812+
813+ self . state = . modifying // avoid CoW
814+ let action = queryState. sendCopyDone ( continuation: continuation)
815+ self . state = . extendedQuery( queryState, connectionContext)
816+ return self . modify ( with: action)
817+ }
818+
819+ /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
820+ mutating func sendCopyFail( message: String , continuation: CheckedContinuation < Void , any Error > ) -> ConnectionAction {
821+ guard case . extendedQuery( var queryState, let connectionContext) = self . state else {
822+ preconditionFailure ( " Copy mode is only supported for extended queries " )
823+ }
824+
825+ self . state = . modifying // avoid CoW
826+ let action = queryState. sendCopyFailed ( message: message, continuation: continuation)
827+ self . state = . extendedQuery( queryState, connectionContext)
828+ return self . modify ( with: action)
829+ }
754830
755831 mutating func emptyQueryResponseReceived( ) -> ConnectionAction {
756832 guard case . extendedQuery( var queryState, let connectionContext) = self . state, !queryState. isComplete else {
@@ -860,14 +936,21 @@ struct ConnectionStateMachine {
860936 . forwardRows,
861937 . forwardStreamComplete,
862938 . wait,
863- . read:
939+ . read,
940+ . triggerCopyData,
941+ . sendCopyDone,
942+ . sendCopyFailed,
943+ . succeedQueryContinuation:
864944 preconditionFailure ( " Invalid query state machine action in state: \( self . state) , action: \( action) " )
865945
866946 case . evaluateErrorAtConnectionLevel:
867947 return . closeConnectionAndCleanup( cleanupContext)
868948
869- case . failQuery( let queryContext, with: let error) :
870- return . failQuery( queryContext, with: error, cleanupContext: cleanupContext)
949+ case . failQuery( let promise, with: let error) :
950+ return . failQuery( promise, with: error, cleanupContext: cleanupContext)
951+
952+ case . failQueryContinuation( let continuation, with: let error) :
953+ return . failQueryContinuation( continuation, with: error, cleanupContext: cleanupContext)
871954
872955 case . forwardStreamError( let error, let read) :
873956 return . forwardStreamError( error, read: read, cleanupContext: cleanupContext)
@@ -1038,8 +1121,19 @@ extension ConnectionStateMachine {
10381121 case . failQuery( let requestContext, with: let error) :
10391122 let cleanupContext = self . setErrorAndCreateCleanupContextIfNeeded ( error)
10401123 return . failQuery( requestContext, with: error, cleanupContext: cleanupContext)
1124+ case . failQueryContinuation( let continuation, with: let error) :
1125+ let cleanupContext = self . setErrorAndCreateCleanupContextIfNeeded ( error)
1126+ return . failQueryContinuation( continuation, with: error, cleanupContext: cleanupContext)
10411127 case . succeedQuery( let requestContext, with: let result) :
10421128 return . succeedQuery( requestContext, with: result)
1129+ case . succeedQueryContinuation( let continuation) :
1130+ return . succeedQueryContinuation( continuation)
1131+ case . triggerCopyData( let triggerCopy) :
1132+ return . triggerCopyData( triggerCopy)
1133+ case . sendCopyDone:
1134+ return . sendCopyDone
1135+ case . sendCopyFailed( message: let message) :
1136+ return . sendCopyFailed( message: message)
10431137 case . forwardRows( let buffer) :
10441138 return . forwardRows( buffer)
10451139 case . forwardStreamComplete( let buffer, let commandTag) :
0 commit comments