Skip to content

Commit 2e878d4

Browse files
committed
Address review comments
1 parent a0cc056 commit 2e878d4

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -696,26 +696,12 @@ extension PostgresConnection {
696696

697697
// MARK: Copy from
698698

699-
fileprivate extension EventLoop {
700-
/// If we are on the given event loop, execute `task` immediately. Otherwise schedule it for execution.
701-
func executeImmediatelyOrSchedule(_ task: @Sendable @escaping () -> Void) {
702-
if inEventLoop {
703-
return task()
704-
}
705-
return execute(task)
706-
}
707-
}
708-
709699
/// A handle to send
710700
public struct PostgresCopyFromWriter: Sendable {
711701
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
712702
private let context: NIOLoopBound<ChannelHandlerContext>
713703
private let eventLoop: any EventLoop
714704

715-
struct NotWritableError: Error, CustomStringConvertible {
716-
var description = "No data must be written to `PostgresCopyFromWriter` after it has sent a CopyDone or CopyFail message, ie. after the closure producing the copy data has finished"
717-
}
718-
719705
init(handler: PostgresChannelHandler, context: ChannelHandlerContext, eventLoop: any EventLoop) {
720706
self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
721707
self.context = NIOLoopBound(context, eventLoop: eventLoop)
@@ -725,8 +711,12 @@ public struct PostgresCopyFromWriter: Sendable {
725711
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
726712
public func write(_ byteBuffer: ByteBuffer) async throws {
727713
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
728-
eventLoop.executeImmediatelyOrSchedule {
714+
if eventLoop.inEventLoop {
729715
self.channelHandler.value.copyData(byteBuffer, context: self.context.value, readyForMoreWriteContinuation: continuation)
716+
} else {
717+
eventLoop.execute {
718+
self.channelHandler.value.copyData(byteBuffer, context: self.context.value, readyForMoreWriteContinuation: continuation)
719+
}
730720
}
731721
}
732722
}
@@ -735,8 +725,12 @@ public struct PostgresCopyFromWriter: Sendable {
735725
/// the backend.
736726
func done() async throws {
737727
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
738-
eventLoop.executeImmediatelyOrSchedule {
728+
if eventLoop.inEventLoop {
739729
self.channelHandler.value.sendCopyDone(continuation: continuation, context: self.context.value)
730+
} else {
731+
eventLoop.execute {
732+
self.channelHandler.value.sendCopyDone(continuation: continuation, context: self.context.value)
733+
}
740734
}
741735
}
742736
}
@@ -745,8 +739,12 @@ public struct PostgresCopyFromWriter: Sendable {
745739
/// the backend.
746740
func failed(error: any Error) async throws {
747741
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
748-
eventLoop.executeImmediatelyOrSchedule {
742+
if eventLoop.inEventLoop {
749743
self.channelHandler.value.sendCopyFailed(message: "\(error)", continuation: continuation, context: self.context.value)
744+
} else {
745+
eventLoop.execute {
746+
self.channelHandler.value.sendCopyFailed(message: "\(error)", continuation: continuation, context: self.context.value)
747+
}
750748
}
751749
}
752750
}

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ struct ConnectionStateMachine {
122122
case succeedClose(CloseCommandContext)
123123
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
124124
}
125+
126+
enum ChannelWritabilityChangedAction {
127+
case none
128+
case resumeContinuation(CheckedContinuation<Void, Never>)
129+
}
125130

126131
private var state: State
127132
private let requireBackendKeyData: Bool
@@ -679,13 +684,14 @@ struct ConnectionStateMachine {
679684
}
680685
}
681686

682-
mutating func channelWritabilityChanged(isWritable: Bool) {
687+
mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
683688
guard case .extendedQuery(var queryState, let connectionContext) = state else {
684-
return
689+
return .none
685690
}
686691
self.state = .modifying // avoid CoW
687-
queryState.channelWritabilityChanged(isWritable: isWritable)
692+
let action = queryState.channelWritabilityChanged(isWritable: isWritable)
688693
self.state = .extendedQuery(queryState, connectionContext)
694+
return action
689695
}
690696

691697
// MARK: - Running Queries -

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -598,13 +598,13 @@ struct ExtendedQueryStateMachine {
598598
}
599599
}
600600

601-
mutating func channelWritabilityChanged(isWritable: Bool) {
601+
mutating func channelWritabilityChanged(isWritable: Bool) -> ConnectionStateMachine.ChannelWritabilityChangedAction {
602602
guard case .copyingData(.pendingBackpressureRelieve(let continuation)) = state else {
603-
return
603+
return .none
604604
}
605-
self.avoidingStateMachineCoW { state in
605+
return self.avoidingStateMachineCoW { state in
606606
state = .copyingData(.readyToSend)
607-
continuation.resume()
607+
return .resumeContinuation(continuation)
608608
}
609609
}
610610

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,13 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
203203
}
204204

205205
func channelWritabilityChanged(context: ChannelHandlerContext) {
206-
self.state.channelWritabilityChanged(isWritable: context.channel.isWritable)
206+
let action = self.state.channelWritabilityChanged(isWritable: context.channel.isWritable)
207+
switch action {
208+
case .none:
209+
break
210+
case .resumeContinuation(let continuation):
211+
continuation.resume()
212+
}
207213
}
208214

209215
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {

0 commit comments

Comments
 (0)