Skip to content

Commit d4d1a2e

Browse files
authored
Make 'finish()' 'async' (#2044)
Motivation: Finishing writes should be `async` as the underlying writer may need to flush and write out any buffered data. Modifications: - Mark `finish()` as `async` - Refactor the in-proc client transport slightly to avoid async calls while holding a lock Result: `finish` is `async`
1 parent 6f396ca commit d4d1a2e

File tree

11 files changed

+95
-87
lines changed

11 files changed

+95
-87
lines changed

Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ internal enum ClientStreamExecutor {
9595

9696
switch result {
9797
case .success:
98-
stream.finish()
98+
await stream.finish()
9999
case .failure(let error):
100-
stream.finish(throwing: error)
100+
await stream.finish(throwing: error)
101101
}
102102
}
103103

Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct ServerRPCExecutor {
5858
// Stream can't be handled; write an error status and close.
5959
let status = Status(code: Status.Code(error.code), message: error.message)
6060
try? await stream.outbound.write(.status(status, error.metadata))
61-
stream.outbound.finish()
61+
await stream.outbound.finish()
6262
}
6363
}
6464

@@ -231,7 +231,7 @@ struct ServerRPCExecutor {
231231
}
232232

233233
try? await outbound.write(.status(status, metadata))
234-
outbound.finish()
234+
await outbound.finish()
235235
}
236236

237237
@inlinable

Sources/GRPCCore/Call/Server/RPCRouter.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ extension RPCRouter {
155155
// If this throws then the stream must be closed which we can't do anything about, so ignore
156156
// any error.
157157
try? await stream.outbound.write(.status(.rpcNotImplemented, [:]))
158-
stream.outbound.finish()
158+
await stream.outbound.finish()
159159
}
160160
}
161161
}

Sources/GRPCCore/Streaming/RPCWriter+Closable.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,17 @@ extension RPCWriter {
5555
/// All writes after ``finish()`` has been called should result in an error
5656
/// being thrown.
5757
@inlinable
58-
public func finish() {
59-
self.writer.finish()
58+
public func finish() async {
59+
await self.writer.finish()
6060
}
6161

6262
/// Indicate to the writer that no more writes are to be accepted because an error occurred.
6363
///
6464
/// All writes after ``finish(throwing:)`` has been called should result in an error
6565
/// being thrown.
6666
@inlinable
67-
public func finish(throwing error: any Error) {
68-
self.writer.finish(throwing: error)
67+
public func finish(throwing error: any Error) async {
68+
await self.writer.finish(throwing: error)
6969
}
7070
}
7171
}

Sources/GRPCCore/Streaming/RPCWriterProtocol.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ public protocol ClosableRPCWriterProtocol<Element>: RPCWriterProtocol {
5757
///
5858
/// All writes after ``finish()`` has been called should result in an error
5959
/// being thrown.
60-
func finish()
60+
func finish() async
6161

6262
/// Indicate to the writer that no more writes are to be accepted because an error occurred.
6363
///
6464
/// All writes after ``finish(throwing:)`` has been called should result in an error
6565
/// being thrown.
66-
func finish(throwing error: any Error)
66+
func finish(throwing error: any Error) async
6767
}
6868

6969
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)

Sources/GRPCInProcessTransport/InProcessClientTransport.swift

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ public final class InProcessClientTransport: ClientTransport {
179179
}
180180

181181
for (clientStream, serverStream) in openStreams {
182-
clientStream.outbound.finish(throwing: CancellationError())
183-
serverStream.outbound.finish(throwing: CancellationError())
182+
await clientStream.outbound.finish(throwing: CancellationError())
183+
await serverStream.outbound.finish(throwing: CancellationError())
184184
}
185185
}
186186

@@ -265,7 +265,7 @@ public final class InProcessClientTransport: ClientTransport {
265265
try Task.checkCancellation()
266266
}
267267

268-
let streamID = try self.state.withLock { state in
268+
let acceptStream: Result<Int, RPCError> = self.state.withLock { state in
269269
switch state {
270270
case .unconnected:
271271
// The state cannot be unconnected because if it was, then the above
@@ -281,56 +281,64 @@ public final class InProcessClientTransport: ClientTransport {
281281
connectedState.openStreams[streamID] = (clientStream, serverStream)
282282
connectedState.nextStreamID += 1
283283
state = .connected(connectedState)
284+
return .success(streamID)
284285
} catch let acceptStreamError as RPCError {
285-
serverStream.outbound.finish(throwing: acceptStreamError)
286-
clientStream.outbound.finish(throwing: acceptStreamError)
287-
throw acceptStreamError
286+
return .failure(acceptStreamError)
288287
} catch {
289-
serverStream.outbound.finish(throwing: error)
290-
clientStream.outbound.finish(throwing: error)
291-
throw RPCError(code: .unknown, message: "Unknown error: \(error).")
288+
return .failure(RPCError(code: .unknown, message: "Unknown error: \(error)."))
292289
}
293-
return streamID
294290

295291
case .closed:
296-
let error = RPCError(
297-
code: .failedPrecondition,
298-
message: "The client transport is closed."
299-
)
300-
serverStream.outbound.finish(throwing: error)
301-
clientStream.outbound.finish(throwing: error)
302-
throw error
292+
let error = RPCError(code: .failedPrecondition, message: "The client transport is closed.")
293+
return .failure(error)
303294
}
304295
}
305296

306-
defer {
307-
clientStream.outbound.finish()
308-
309-
let maybeEndContinuation = self.state.withLock { state in
310-
switch state {
311-
case .unconnected:
312-
// The state cannot be unconnected at this point, because if we made
313-
// it this far, it's because the transport was connected.
314-
// Once connected, it's impossible to transition back to unconnected,
315-
// so this is an invalid state.
316-
fatalError("Invalid state")
317-
case .connected(var connectedState):
318-
connectedState.openStreams.removeValue(forKey: streamID)
319-
state = .connected(connectedState)
320-
case .closed(var closedState):
321-
closedState.openStreams.removeValue(forKey: streamID)
322-
state = .closed(closedState)
323-
if closedState.openStreams.isEmpty {
324-
// This was the last open stream: signal the closure of the client.
325-
return closedState.signalEndContinuation
326-
}
327-
}
328-
return nil
297+
switch acceptStream {
298+
case .success(let streamID):
299+
let streamHandlingResult: Result<T, any Error>
300+
do {
301+
let result = try await closure(clientStream)
302+
streamHandlingResult = .success(result)
303+
} catch {
304+
streamHandlingResult = .failure(error)
329305
}
330-
maybeEndContinuation?.finish()
306+
307+
await clientStream.outbound.finish()
308+
self.removeStream(id: streamID)
309+
310+
return try streamHandlingResult.get()
311+
312+
case .failure(let error):
313+
await serverStream.outbound.finish(throwing: error)
314+
await clientStream.outbound.finish(throwing: error)
315+
throw error
331316
}
317+
}
332318

333-
return try await closure(clientStream)
319+
private func removeStream(id streamID: Int) {
320+
let maybeEndContinuation = self.state.withLock { state in
321+
switch state {
322+
case .unconnected:
323+
// The state cannot be unconnected at this point, because if we made
324+
// it this far, it's because the transport was connected.
325+
// Once connected, it's impossible to transition back to unconnected,
326+
// so this is an invalid state.
327+
fatalError("Invalid state")
328+
case .connected(var connectedState):
329+
connectedState.openStreams.removeValue(forKey: streamID)
330+
state = .connected(connectedState)
331+
case .closed(var closedState):
332+
closedState.openStreams.removeValue(forKey: streamID)
333+
state = .closed(closedState)
334+
if closedState.openStreams.isEmpty {
335+
// This was the last open stream: signal the closure of the client.
336+
return closedState.signalEndContinuation
337+
}
338+
}
339+
return nil
340+
}
341+
maybeEndContinuation?.finish()
334342
}
335343

336344
/// Returns the execution configuration for a given method.

Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
7474

7575
try await stream.outbound.write(contentsOf: response)
7676
try await stream.outbound.write(.status(Status(code: .ok, message: ""), [:]))
77-
stream.outbound.finish()
77+
await stream.outbound.finish()
7878
}
7979
}
8080

@@ -90,7 +90,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
9090
// All error codes are valid status codes, '!' is safe.
9191
let status = Status(code: Status.Code(error.code), message: error.message)
9292
try await stream.outbound.write(.status(status, error.metadata))
93-
stream.outbound.finish()
93+
await stream.outbound.finish()
9494
}
9595
}
9696

@@ -99,7 +99,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
9999
XCTFail("Server accepted unexpected stream")
100100
let status = Status(code: .unknown, message: "Unexpected stream")
101101
try await stream.outbound.write(.status(status, [:]))
102-
stream.outbound.finish()
102+
await stream.outbound.finish()
103103
}
104104
}
105105

Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ final class ServerRPCExecutorTests: XCTestCase {
2424
let harness = ServerRPCExecutorTestHarness()
2525
try await harness.execute(handler: .echo) { inbound in
2626
try await inbound.write(.metadata(["foo": "bar"]))
27-
inbound.finish()
27+
await inbound.finish()
2828
} consumer: { outbound in
2929
let parts = try await outbound.collect()
3030
XCTAssertEqual(
@@ -42,7 +42,7 @@ final class ServerRPCExecutorTests: XCTestCase {
4242
try await harness.execute(handler: .echo) { inbound in
4343
try await inbound.write(.metadata(["foo": "bar"]))
4444
try await inbound.write(.message([0]))
45-
inbound.finish()
45+
await inbound.finish()
4646
} consumer: { outbound in
4747
let parts = try await outbound.collect()
4848
XCTAssertEqual(
@@ -63,7 +63,7 @@ final class ServerRPCExecutorTests: XCTestCase {
6363
try await inbound.write(.message([0]))
6464
try await inbound.write(.message([1]))
6565
try await inbound.write(.message([2]))
66-
inbound.finish()
66+
await inbound.finish()
6767
} consumer: { outbound in
6868
let parts = try await outbound.collect()
6969
XCTAssertEqual(
@@ -94,7 +94,7 @@ final class ServerRPCExecutorTests: XCTestCase {
9494
} producer: { inbound in
9595
try await inbound.write(.metadata(["foo": "bar"]))
9696
try await inbound.write(.message(Array("\"hello\"".utf8)))
97-
inbound.finish()
97+
await inbound.finish()
9898
} consumer: { outbound in
9999
let parts = try await outbound.collect()
100100
XCTAssertEqual(
@@ -125,7 +125,7 @@ final class ServerRPCExecutorTests: XCTestCase {
125125
try await inbound.write(.metadata(["foo": "bar"]))
126126
try await inbound.write(.message(Array("\"hello\"".utf8)))
127127
try await inbound.write(.message(Array("\"world\"".utf8)))
128-
inbound.finish()
128+
await inbound.finish()
129129
} consumer: { outbound in
130130
let parts = try await outbound.collect()
131131
XCTAssertEqual(
@@ -151,7 +151,7 @@ final class ServerRPCExecutorTests: XCTestCase {
151151
}
152152
} producer: { inbound in
153153
try await inbound.write(.metadata(["foo": "bar"]))
154-
inbound.finish()
154+
await inbound.finish()
155155
} consumer: { outbound in
156156
let parts = try await outbound.collect()
157157
XCTAssertEqual(
@@ -167,7 +167,7 @@ final class ServerRPCExecutorTests: XCTestCase {
167167
func testEmptyInbound() async throws {
168168
let harness = ServerRPCExecutorTestHarness()
169169
try await harness.execute(handler: .echo) { inbound in
170-
inbound.finish()
170+
await inbound.finish()
171171
} consumer: { outbound in
172172
let part = try await outbound.collect().first
173173
XCTAssertStatus(part) { status, _ in
@@ -180,7 +180,7 @@ final class ServerRPCExecutorTests: XCTestCase {
180180
let harness = ServerRPCExecutorTestHarness()
181181
try await harness.execute(handler: .echo) { inbound in
182182
try await inbound.write(.message([0]))
183-
inbound.finish()
183+
await inbound.finish()
184184
} consumer: { outbound in
185185
let part = try await outbound.collect().first
186186
XCTAssertStatus(part) { status, _ in
@@ -192,7 +192,7 @@ final class ServerRPCExecutorTests: XCTestCase {
192192
func testInboundStreamThrows() async throws {
193193
let harness = ServerRPCExecutorTestHarness()
194194
try await harness.execute(handler: .echo) { inbound in
195-
inbound.finish(throwing: RPCError(code: .aborted, message: ""))
195+
await inbound.finish(throwing: RPCError(code: .aborted, message: ""))
196196
} consumer: { outbound in
197197
let part = try await outbound.collect().first
198198
XCTAssertStatus(part) { status, _ in
@@ -206,7 +206,7 @@ final class ServerRPCExecutorTests: XCTestCase {
206206
let harness = ServerRPCExecutorTestHarness()
207207
try await harness.execute(handler: .throwing(SomeError())) { inbound in
208208
try await inbound.write(.metadata([:]))
209-
inbound.finish()
209+
await inbound.finish()
210210
} consumer: { outbound in
211211
let part = try await outbound.collect().first
212212
XCTAssertStatus(part) { status, _ in
@@ -220,7 +220,7 @@ final class ServerRPCExecutorTests: XCTestCase {
220220
let harness = ServerRPCExecutorTestHarness()
221221
try await harness.execute(handler: .throwing(error)) { inbound in
222222
try await inbound.write(.metadata([:]))
223-
inbound.finish()
223+
await inbound.finish()
224224
} consumer: { outbound in
225225
let part = try await outbound.collect().first
226226
XCTAssertStatus(part) { status, metadata in
@@ -247,7 +247,7 @@ final class ServerRPCExecutorTests: XCTestCase {
247247
return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: ""))
248248
} producer: { inbound in
249249
try await inbound.write(.metadata(["grpc-timeout": "1000n"]))
250-
inbound.finish()
250+
await inbound.finish()
251251
} consumer: { outbound in
252252
let part = try await outbound.collect().first
253253
XCTAssertStatus(part) { status, _ in
@@ -277,7 +277,7 @@ final class ServerRPCExecutorTests: XCTestCase {
277277
)
278278
} producer: { inbound in
279279
try await inbound.write(.metadata([:]))
280-
inbound.finish()
280+
await inbound.finish()
281281
} consumer: { outbound in
282282
let part = try await outbound.collect().first
283283
XCTAssertStatus(part) { status, metadata in
@@ -302,7 +302,7 @@ final class ServerRPCExecutorTests: XCTestCase {
302302

303303
try await harness.execute(handler: .echo) { inbound in
304304
try await inbound.write(.metadata([:]))
305-
inbound.finish()
305+
await inbound.finish()
306306
} consumer: { outbound in
307307
let parts = try await outbound.collect()
308308
XCTAssertEqual(parts, [.metadata([:]), .status(.ok, [:])])
@@ -327,7 +327,7 @@ final class ServerRPCExecutorTests: XCTestCase {
327327

328328
try await harness.execute(handler: .echo) { inbound in
329329
try await inbound.write(.metadata([:]))
330-
inbound.finish()
330+
await inbound.finish()
331331
} consumer: { outbound in
332332
let parts = try await outbound.collect()
333333
XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: ""), [:])])
@@ -345,7 +345,7 @@ final class ServerRPCExecutorTests: XCTestCase {
345345

346346
try await harness.execute(handler: .echo) { inbound in
347347
try await inbound.write(.metadata([:]))
348-
inbound.finish()
348+
await inbound.finish()
349349
} consumer: { outbound in
350350
let parts = try await outbound.collect()
351351
XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: "Unavailable"), [:])])

0 commit comments

Comments
 (0)