|
2 | 2 | // |
3 | 3 | // This source file is part of the SwiftAWSLambdaRuntime open source project |
4 | 4 | // |
5 | | -// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors |
| 5 | +// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors |
6 | 6 | // Licensed under Apache License v2.0 |
7 | 7 | // |
8 | 8 | // See LICENSE.txt for license information |
@@ -76,7 +76,7 @@ extension Lambda { |
76 | 76 | /// 1. POST /invoke - the client posts the event to the lambda function |
77 | 77 | /// |
78 | 78 | /// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client. |
79 | | -private struct LambdaHTTPServer { |
| 79 | +internal struct LambdaHTTPServer { |
80 | 80 | private let invocationEndpoint: String |
81 | 81 |
|
82 | 82 | private let invocationPool = Pool<LocalServerInvocation>() |
@@ -166,17 +166,21 @@ private struct LambdaHTTPServer { |
166 | 166 | // consumed by iterating the group or by exiting the group. Since, we are never consuming |
167 | 167 | // the results of the group we need the group to automatically discard them; otherwise, this |
168 | 168 | // would result in a memory leak over time. |
169 | | - try await withThrowingDiscardingTaskGroup { taskGroup in |
170 | | - try await channel.executeThenClose { inbound in |
171 | | - for try await connectionChannel in inbound { |
172 | | - |
173 | | - taskGroup.addTask { |
174 | | - logger.trace("Handling a new connection") |
175 | | - await server.handleConnection(channel: connectionChannel, logger: logger) |
176 | | - logger.trace("Done handling the connection") |
| 169 | + try await withTaskCancellationHandler { |
| 170 | + try await withThrowingDiscardingTaskGroup { taskGroup in |
| 171 | + try await channel.executeThenClose { inbound in |
| 172 | + for try await connectionChannel in inbound { |
| 173 | + |
| 174 | + taskGroup.addTask { |
| 175 | + logger.trace("Handling a new connection") |
| 176 | + await server.handleConnection(channel: connectionChannel, logger: logger) |
| 177 | + logger.trace("Done handling the connection") |
| 178 | + } |
177 | 179 | } |
178 | 180 | } |
179 | 181 | } |
| 182 | + } onCancel: { |
| 183 | + channel.channel.close(promise: nil) |
180 | 184 | } |
181 | 185 | return .serverReturned(.success(())) |
182 | 186 | } catch { |
@@ -230,38 +234,42 @@ private struct LambdaHTTPServer { |
230 | 234 | // Note that this method is non-throwing and we are catching any error. |
231 | 235 | // We do this since we don't want to tear down the whole server when a single connection |
232 | 236 | // encounters an error. |
233 | | - do { |
234 | | - try await channel.executeThenClose { inbound, outbound in |
235 | | - for try await inboundData in inbound { |
236 | | - switch inboundData { |
237 | | - case .head(let head): |
238 | | - requestHead = head |
239 | | - |
240 | | - case .body(let body): |
241 | | - requestBody.setOrWriteImmutableBuffer(body) |
242 | | - |
243 | | - case .end: |
244 | | - precondition(requestHead != nil, "Received .end without .head") |
245 | | - // process the request |
246 | | - let response = try await self.processRequest( |
247 | | - head: requestHead, |
248 | | - body: requestBody, |
249 | | - logger: logger |
250 | | - ) |
251 | | - // send the responses |
252 | | - try await self.sendResponse( |
253 | | - response: response, |
254 | | - outbound: outbound, |
255 | | - logger: logger |
256 | | - ) |
257 | | - |
258 | | - requestHead = nil |
259 | | - requestBody = nil |
| 237 | + await withTaskCancellationHandler { |
| 238 | + do { |
| 239 | + try await channel.executeThenClose { inbound, outbound in |
| 240 | + for try await inboundData in inbound { |
| 241 | + switch inboundData { |
| 242 | + case .head(let head): |
| 243 | + requestHead = head |
| 244 | + |
| 245 | + case .body(let body): |
| 246 | + requestBody.setOrWriteImmutableBuffer(body) |
| 247 | + |
| 248 | + case .end: |
| 249 | + precondition(requestHead != nil, "Received .end without .head") |
| 250 | + // process the request |
| 251 | + let response = try await self.processRequest( |
| 252 | + head: requestHead, |
| 253 | + body: requestBody, |
| 254 | + logger: logger |
| 255 | + ) |
| 256 | + // send the responses |
| 257 | + try await self.sendResponse( |
| 258 | + response: response, |
| 259 | + outbound: outbound, |
| 260 | + logger: logger |
| 261 | + ) |
| 262 | + |
| 263 | + requestHead = nil |
| 264 | + requestBody = nil |
| 265 | + } |
260 | 266 | } |
261 | 267 | } |
| 268 | + } catch { |
| 269 | + logger.error("Hit error: \(error)") |
262 | 270 | } |
263 | | - } catch { |
264 | | - logger.error("Hit error: \(error)") |
| 271 | + } onCancel: { |
| 272 | + channel.channel.close(promise: nil) |
265 | 273 | } |
266 | 274 | } |
267 | 275 |
|
@@ -426,7 +434,7 @@ private struct LambdaHTTPServer { |
426 | 434 | /// A shared data structure to store the current invocation or response requests and the continuation objects. |
427 | 435 | /// This data structure is shared between instances of the HTTPHandler |
428 | 436 | /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). |
429 | | - private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { |
| 437 | + internal final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { |
430 | 438 | typealias Element = T |
431 | 439 |
|
432 | 440 | enum State: ~Copyable { |
@@ -462,26 +470,38 @@ private struct LambdaHTTPServer { |
462 | 470 | return nil |
463 | 471 | } |
464 | 472 |
|
465 | | - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in |
466 | | - let nextAction = self.lock.withLock { state -> T? in |
467 | | - switch consume state { |
468 | | - case .buffer(var buffer): |
469 | | - if let first = buffer.popFirst() { |
470 | | - state = .buffer(buffer) |
471 | | - return first |
472 | | - } else { |
473 | | - state = .continuation(continuation) |
474 | | - return nil |
475 | | - } |
| 473 | + return try await withTaskCancellationHandler { |
| 474 | + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in |
| 475 | + let nextAction = self.lock.withLock { state -> T? in |
| 476 | + switch consume state { |
| 477 | + case .buffer(var buffer): |
| 478 | + if let first = buffer.popFirst() { |
| 479 | + state = .buffer(buffer) |
| 480 | + return first |
| 481 | + } else { |
| 482 | + state = .continuation(continuation) |
| 483 | + return nil |
| 484 | + } |
476 | 485 |
|
477 | | - case .continuation: |
478 | | - fatalError("Concurrent invocations to next(). This is illegal.") |
| 486 | + case .continuation: |
| 487 | + fatalError("Concurrent invocations to next(). This is illegal.") |
| 488 | + } |
479 | 489 | } |
480 | | - } |
481 | 490 |
|
482 | | - guard let nextAction else { return } |
| 491 | + guard let nextAction else { return } |
483 | 492 |
|
484 | | - continuation.resume(returning: nextAction) |
| 493 | + continuation.resume(returning: nextAction) |
| 494 | + } |
| 495 | + } onCancel: { |
| 496 | + self.lock.withLock { state in |
| 497 | + switch consume state { |
| 498 | + case .buffer(let buffer): |
| 499 | + state = .buffer(buffer) |
| 500 | + case .continuation(let continuation): |
| 501 | + continuation?.resume(throwing: CancellationError()) |
| 502 | + state = .buffer([]) |
| 503 | + } |
| 504 | + } |
485 | 505 | } |
486 | 506 | } |
487 | 507 |
|
|
0 commit comments