Skip to content

Commit cca91b8

Browse files
committed
Address review comments
1 parent c5f2928 commit cca91b8

File tree

3 files changed

+44
-57
lines changed

3 files changed

+44
-57
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
22
public struct PostgresCopyFromWriter: Sendable {
3-
/// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed.
4-
///
5-
/// The `PostgresCopyFromWriter` should cancel the data transfer.
6-
public struct CopyCancellationError: Error {
7-
/// The error that the backend sent us which cancelled the data transfer.
8-
///
9-
/// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
10-
/// new data is written by `write`.
11-
public let underlyingError: PSQLError
12-
}
13-
143
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
154
private let eventLoop: any EventLoop
165

@@ -42,9 +31,9 @@ public struct PostgresCopyFromWriter: Sendable {
4231

4332
/// Send data for a `COPY ... FROM STDIN` operation to the backend.
4433
///
45-
/// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
46-
/// a `CopyCancellationError`.
47-
public func write(_ byteBuffer: ByteBuffer) async throws {
34+
/// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy
35+
/// operation, eg. to indicate that a **previous** `write` call had an invalid format.
36+
public func write(_ byteBuffer: ByteBuffer, isolation: isolated (any Actor)? = #isolation) async throws {
4837
// Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the
4938
// `writeData` closure. It is likely that the user would forget to do so.
5039
try Task.checkCancellation()
@@ -82,7 +71,7 @@ public struct PostgresCopyFromWriter: Sendable {
8271

8372
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
8473
/// the backend.
85-
func done() async throws {
74+
func done(isolation: isolated (any Actor)? = #isolation) async throws {
8675
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
8776
if eventLoop.inEventLoop {
8877
self.channelHandler.value.sendCopyDone(continuation: continuation)
@@ -96,37 +85,43 @@ public struct PostgresCopyFromWriter: Sendable {
9685

9786
/// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
9887
/// the backend.
99-
func failed(error: any Error) async throws {
88+
func failed(error: any Error, isolation: isolated (any Actor)? = #isolation) async throws {
10089
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
101-
// TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend
102-
// here? We could also use a generic description, it doesn't really matter since we throw the user's error
103-
// in `copyFrom`.
10490
if eventLoop.inEventLoop {
105-
self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
91+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
10692
} else {
10793
eventLoop.execute {
108-
self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
94+
self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation)
10995
}
11096
}
11197
}
11298
}
11399
}
114100

115101
/// Specifies the format in which data is transferred to the backend in a COPY operation.
116-
public enum PostgresCopyFromFormat: Sendable {
102+
///
103+
/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings
104+
/// and their default values.
105+
public struct PostgresCopyFromFormat: Sendable {
117106
/// Options that can be used to modify the `text` format of a COPY operation.
118107
public struct TextOptions: Sendable {
119108
/// The delimiter that separates columns in the data.
120109
///
121110
/// See the `DELIMITER` option in Postgres's `COPY` command.
122-
///
123-
/// Uses the default delimiter of the format
124111
public var delimiter: UnicodeScalar? = nil
125112

126113
public init() {}
127114
}
128115

129-
case text(TextOptions)
116+
enum Format {
117+
case text(TextOptions)
118+
}
119+
120+
var format: Format
121+
122+
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
123+
return PostgresCopyFromFormat(format: .text(options))
124+
}
130125
}
131126

132127
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
@@ -138,14 +133,17 @@ private func buildCopyFromQuery(
138133
columns: [StaticString] = [],
139134
format: PostgresCopyFromFormat
140135
) -> PostgresQuery {
141-
// TODO: Should we put the table and column names in quotes to make them case-sensitive?
142-
var query = "COPY \(table)"
136+
var query = """
137+
COPY "\(table)"
138+
"""
143139
if !columns.isEmpty {
144-
query += "(" + columns.map(\.description).joined(separator: ",") + ")"
140+
query += "("
141+
query += columns.map { #"""# + $0.description + #"""# }.joined(separator: ",")
142+
query += ")"
145143
}
146144
query += " FROM STDIN"
147145
var queryOptions: [String] = []
148-
switch format {
146+
switch format.format {
149147
case .text(let options):
150148
queryOptions.append("FORMAT text")
151149
if let delimiter = options.delimiter {
@@ -179,6 +177,7 @@ extension PostgresConnection {
179177
columns: [StaticString] = [],
180178
format: PostgresCopyFromFormat = .text(.init()),
181179
logger: Logger,
180+
isolation: isolated (any Actor)? = #isolation,
182181
file: String = #fileID,
183182
line: Int = #line,
184183
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void
@@ -205,22 +204,13 @@ extension PostgresConnection {
205204
// threw instead of the one that got relayed back, so it's better to ignore the error here.
206205
// - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts
207206
// the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger
208-
// a `Sync` that takes the backend out of copy mode. If `writeData` threw the `CopyCancellationError`
209-
// from the `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it
210-
// doesn't matter that we ignore the error here. If the user threw some other error, it's better to honor
211-
// the user's error.
207+
// a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the
208+
// `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't
209+
// matter that we ignore the error here. If the user threw some other error, it's better to honor the
210+
// user's error.
212211
try? await writer.failed(error: error)
213212

214-
if let error = error as? PostgresCopyFromWriter.CopyCancellationError {
215-
// If we receive a `CopyCancellationError` that is with almost certain likelihood because
216-
// `PostgresCopyFromWriter.write` threw it - otherwise the user must have saved a previous
217-
// `PostgresCopyFromWriter` error, which is very unlikely.
218-
// Throw the underlying error because that contains the error message that was sent by the backend and
219-
// is most actionable by the user.
220-
throw error.underlyingError
221-
} else {
222-
throw error
223-
}
213+
throw error
224214
}
225215

226216
// `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during
@@ -230,5 +220,4 @@ extension PostgresConnection {
230220
// above.
231221
try await writer.done()
232222
}
233-
234223
}

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ struct ExtendedQueryStateMachine {
427427
if case .error(let error) = self.state {
428428
// The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should
429429
// abort the data transfer.
430-
promise.fail(PostgresCopyFromWriter.CopyCancellationError(underlyingError: error))
430+
promise.fail(error)
431431
return
432432
}
433433
guard case .copyingData(.readyToSend) = self.state else {

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ class PostgresConnectionTests: XCTestCase {
629629
try await assertCopyFrom { writer in
630630
try await writer.write(ByteBuffer(staticString: "1\tAlice\n"))
631631
} validateCopyRequest: { copyRequest in
632-
XCTAssertEqual(copyRequest.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (FORMAT text)")
632+
XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text)"#)
633633
XCTAssertEqual(copyRequest.bind.parameters, [])
634634
} mockBackend: { channel, _ in
635635
let data = try await channel.waitForCopyData()
@@ -646,7 +646,7 @@ class PostgresConnectionTests: XCTestCase {
646646
try await assertCopyFrom(format: .text(options)) { writer in
647647
try await writer.write(ByteBuffer(staticString: "1,Alice\n"))
648648
} validateCopyRequest: { copyRequest in
649-
XCTAssertEqual(copyRequest.parse.query, #"COPY copy_table(id,name) FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#)
649+
XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#)
650650
XCTAssertEqual(copyRequest.bind.parameters, [])
651651
} mockBackend: { channel, _ in
652652
let data = try await channel.waitForCopyData()
@@ -657,19 +657,17 @@ class PostgresConnectionTests: XCTestCase {
657657
}
658658

659659
func testCopyFromWriterFails() async throws {
660-
struct MyError: Error, CustomStringConvertible {
661-
var description: String { "My error" }
662-
}
660+
struct MyError: Error {}
663661

664662
try await assertCopyFrom { writer in
665663
throw MyError()
666664
} validateCopyFromError: { error in
667665
XCTAssert(error is MyError, "Expected error of type MyError, got \(error)")
668666
} mockBackend: { channel, _ in
669667
let data = try await channel.waitForCopyData()
670-
XCTAssertEqual(data.result, .failed(message: "My error"))
668+
XCTAssertEqual(data.result, .failed(message: "Client failed copy"))
671669
try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [
672-
.message: "COPY from stdin failed: My error",
670+
.message: "COPY from stdin failed: Client failed copy",
673671
.sqlState : "57014" // query_canceled
674672
])))
675673
}
@@ -752,7 +750,7 @@ class PostgresConnectionTests: XCTestCase {
752750
try await writer.write(ByteBuffer(staticString: "2\tBob\n"))
753751
XCTFail("Expected error to be thrown")
754752
} catch {
755-
XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)")
753+
XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02")
756754
throw error
757755
}
758756
} validateCopyFromError: { error in
@@ -769,7 +767,7 @@ class PostgresConnectionTests: XCTestCase {
769767
}
770768
}
771769

772-
func testCopyFromCallerDoesNotRethrowCopyCancellationError() async throws {
770+
func testCopyFromCallerDoesNotRethrowFromWriteCall() async throws {
773771
struct MyError: Error, CustomStringConvertible {
774772
var description: String { "My error" }
775773
}
@@ -785,7 +783,7 @@ class PostgresConnectionTests: XCTestCase {
785783
try await writer.write(ByteBuffer(staticString: "2\tBob\n"))
786784
XCTFail("Expected error to be thrown")
787785
} catch {
788-
XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)")
786+
XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02")
789787
throw MyError()
790788
}
791789
} validateCopyFromError: { error in
@@ -875,10 +873,10 @@ class PostgresConnectionTests: XCTestCase {
875873
cancelCopy()
876874

877875
let data = try await channel.waitForCopyData()
878-
XCTAssertEqual(data.result, .failed(message: "CancellationError()"))
876+
XCTAssertEqual(data.result, .failed(message: "Client failed copy"))
879877

880878
try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [
881-
.message: "COPY from stdin failed: CancellationError()",
879+
.message: "COPY from stdin failed: Client failed copy",
882880
.sqlState : "57014" // query_canceled
883881
])))
884882
}

0 commit comments

Comments
 (0)