diff --git a/Package.swift b/Package.swift index 99ea9667..4cce3abd 100644 --- a/Package.swift +++ b/Package.swift @@ -6,7 +6,9 @@ let defaultSwiftSettings: [SwiftSetting] = [ .enableExperimentalFeature( "AvailabilityMacro=LambdaSwift 2.0:macOS 15.0" - ) + ), + .enableUpcomingFeature("NonisolatedNonsendingByDefault"), + .enableUpcomingFeature("InferIsolatedConformances"), ] let package = Package( diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler-pre-6.2.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler-pre-6.2.swift new file mode 100644 index 00000000..be0973f2 --- /dev/null +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler-pre-6.2.swift @@ -0,0 +1,491 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright SwiftAWSLambdaRuntime project authors +// Copyright (c) Amazon.com, Inc. or its affiliates. +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if swift(<6.2) +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix + +internal protocol LambdaChannelHandlerDelegate { + func connectionWillClose(channel: any Channel) + func connectionErrorHappened(_ error: any Error, channel: any Channel) +} + +@available(LambdaSwift 2.0, *) +internal final class LambdaChannelHandler { + let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix + + enum State { + case disconnected + case connected(ChannelHandlerContext, LambdaState) + case closing + + enum LambdaState { + /// this is the "normal" state. Transitions to `waitingForNextInvocation` + case idle + /// this is the state while we wait for an invocation. A next call is running. + /// Transitions to `waitingForResponse` + case waitingForNextInvocation(CheckedContinuation) + /// The invocation was forwarded to the handler and we wait for a response. + /// Transitions to `sendingResponse` or `sentResponse`. + case waitingForResponse + case sendingResponse + case sentResponse(CheckedContinuation) + } + } + + private var state: State = .disconnected + private var lastError: Error? + private var reusableErrorBuffer: ByteBuffer? + private let logger: Logger + private let delegate: Delegate + private let configuration: LambdaRuntimeClient.Configuration + + /// These are the default headers that must be sent along an invocation + let defaultHeaders: HTTPHeaders + /// These headers must be sent along an invocation or initialization error report + let errorHeaders: HTTPHeaders + /// These headers must be sent when streaming a large response + let largeResponseHeaders: HTTPHeaders + /// These headers must be sent when the handler streams its response + let streamingHeaders: HTTPHeaders + + init( + delegate: Delegate, + logger: Logger, + configuration: LambdaRuntimeClient.Configuration + ) { + self.delegate = delegate + self.logger = logger + self.configuration = configuration + self.defaultHeaders = [ + "host": "\(self.configuration.ip):\(self.configuration.port)", + "user-agent": .userAgent, + ] + self.errorHeaders = [ + "host": "\(self.configuration.ip):\(self.configuration.port)", + "user-agent": .userAgent, + "lambda-runtime-function-error-type": "Unhandled", + ] + self.largeResponseHeaders = [ + "host": "\(self.configuration.ip):\(self.configuration.port)", + "user-agent": .userAgent, + "transfer-encoding": "chunked", + ] + // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming + // These are the headers returned by the Runtime to the Lambda Data plane. + // These are not the headers the Lambda Data plane sends to the caller of the Lambda function + // The developer of the function can set the caller's headers in the handler code. + self.streamingHeaders = [ + "host": "\(self.configuration.ip):\(self.configuration.port)", + "user-agent": .userAgent, + "Lambda-Runtime-Function-Response-Mode": "streaming", + // these are not used by this runtime client at the moment + // FIXME: the eror handling should inject these headers in the streamed response to report mid-stream errors + "Trailer": "Lambda-Runtime-Function-Error-Type, Lambda-Runtime-Function-Error-Body", + ] + } + + func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation { + switch self.state { + case .connected(let context, .idle): + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + self.state = .connected(context, .waitingForNextInvocation(continuation)) + self.sendNextRequest(context: context) + } + + case .connected(_, .sendingResponse), + .connected(_, .sentResponse), + .connected(_, .waitingForNextInvocation), + .connected(_, .waitingForResponse), + .closing: + fatalError("Invalid state: \(self.state)") + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + } + } + + func reportError( + isolation: isolated (any Actor)? = #isolation, + _ error: any Error, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendReportErrorRequest(requestID: requestID, error: error, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseStreamingFailure(error: error, context: context) + } + + case .connected(_, .idle), + .connected(_, .sentResponse): + // The final response has already been sent. The only way to report the unhandled error + // now is to log it. Normally this library never logs higher than debug, we make an + // exception here, as there is no other way of reporting the error. + self.logger.error( + "Unhandled error after stream has finished", + metadata: [ + "lambda_request_id": "\(requestID)", + "lambda_error": "\(String(describing: error))", + ] + ) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func writeResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + requestID: String, + hasCustomHeaders: Bool + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + self.state = .connected(context, .sendingResponse) + try await self.sendResponseBodyPart( + byteBuffer, + sendHeadWithRequestID: requestID, + context: context, + hasCustomHeaders: hasCustomHeaders + ) + + case .connected(let context, .sendingResponse): + + precondition(!hasCustomHeaders, "Programming error: Custom headers should not be sent in this state") + + try await self.sendResponseBodyPart( + byteBuffer, + sendHeadWithRequestID: nil, + context: context, + hasCustomHeaders: hasCustomHeaders + ) + + case .connected(_, .idle), + .connected(_, .sentResponse): + throw LambdaRuntimeError(code: .writeAfterFinishHasBeenSent) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func finishResponseRequest( + isolation: isolated (any Actor)? = #isolation, + finalData: ByteBuffer?, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .idle), + .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: requestID, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: nil, context: context) + } + + case .connected(_, .sentResponse): + throw LambdaRuntimeError(code: .finishAfterFinishHasBeenSent) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + private func sendResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext, + hasCustomHeaders: Bool + ) async throws { + + if let requestID = sendHeadWithRequestID { + // TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length. + let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix + + var headers = self.streamingHeaders + if hasCustomHeaders { + // this header is required by Function URL when the user sends custom status code or headers + headers.add(name: "Content-Type", value: "application/vnd.awslambda.http-integration-response") + } + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: headers + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + let future = context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer)))) + context.flush() + try await future.get() + } + + private func sendResponseFinish( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer?, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext + ) { + if let requestID = sendHeadWithRequestID { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length. + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postResponseURLSuffix)" + + // If we have less than 6MB, we don't want to use the streaming API. If we have more + // than 6MB, we must use the streaming mode. + var headers: HTTPHeaders! + if byteBuffer?.readableBytes ?? 0 < 6_000_000 { + headers = self.defaultHeaders + headers.add(name: "content-length", value: "\(byteBuffer?.readableBytes ?? 0)") + } else { + headers = self.largeResponseHeaders + } + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: headers + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + if let byteBuffer { + context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer))), promise: nil) + } + + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendNextRequest(context: ChannelHandlerContext) { + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: self.nextInvocationPath, + headers: self.defaultHeaders + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendReportErrorRequest(requestID: String, error: any Error, context: ChannelHandlerContext) { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postErrorURLSuffix)" + + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: self.errorHeaders + ) + + if self.reusableErrorBuffer == nil { + self.reusableErrorBuffer = context.channel.allocator.buffer(capacity: 1024) + } else { + self.reusableErrorBuffer!.clear() + } + + let errorResponse = ErrorResponse(errorType: "\(type(of: error))", errorMessage: "\(error)") + // TODO: Write this directly into our ByteBuffer + let bytes = errorResponse.toJSONBytes() + self.reusableErrorBuffer!.writeBytes(bytes) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(self.reusableErrorBuffer!))), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendResponseStreamingFailure(error: any Error, context: ChannelHandlerContext) { + let trailers: HTTPHeaders = [ + "Lambda-Runtime-Function-Error-Type": "Unhandled", + "Lambda-Runtime-Function-Error-Body": "Requires base64", + ] + + context.write(self.wrapOutboundOut(.end(trailers)), promise: nil) + context.flush() + } + + func cancelCurrentRequestAndCloseConnection() { + fatalError("Unimplemented") + } +} + +@available(LambdaSwift 2.0, *) +extension LambdaChannelHandler: ChannelInboundHandler { + typealias OutboundIn = Never + typealias InboundIn = NIOHTTPClientResponseFull + typealias OutboundOut = HTTPClientRequestPart + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.state = .connected(context, .idle) + } + } + + func channelActive(context: ChannelHandlerContext) { + switch self.state { + case .disconnected: + self.state = .connected(context, .idle) + case .connected: + break + case .closing: + fatalError("Invalid state: \(self.state)") + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let response = unwrapInboundIn(data) + + // handle response content + + switch self.state { + case .connected(let context, .waitingForNextInvocation(let continuation)): + do { + let metadata = try InvocationMetadata(headers: response.head.headers) + self.state = .connected(context, .waitingForResponse) + continuation.resume(returning: Invocation(metadata: metadata, event: response.body ?? ByteBuffer())) + } catch { + self.state = .closing + + self.delegate.connectionWillClose(channel: context.channel) + context.close(promise: nil) + continuation.resume( + throwing: LambdaRuntimeError(code: .invocationMissingMetadata, underlying: error) + ) + } + + case .connected(let context, .sentResponse(let continuation)): + if response.head.status == .accepted { + self.state = .connected(context, .idle) + continuation.resume() + } else { + self.state = .connected(context, .idle) + continuation.resume(throwing: LambdaRuntimeError(code: .unexpectedStatusCodeForRequest)) + } + + case .disconnected, .closing, .connected(_, _): + break + } + + // As defined in RFC 7230 Section 6.3: + // HTTP/1.1 defaults to the use of "persistent connections", allowing + // multiple requests and responses to be carried over a single + // connection. The "close" connection option is used to signal that a + // connection will not persist after the current request/response. HTTP + // implementations SHOULD support persistent connections. + // + // That's why we only assume the connection shall be closed if we receive + // a "connection = close" header. + let serverCloseConnection = + response.head.headers["connection"].contains(where: { $0.lowercased() == "close" }) + + let closeConnection = serverCloseConnection || response.head.version != .http1_1 + + if closeConnection { + // If we were succeeding the request promise here directly and closing the connection + // after succeeding the promise we may run into a race condition: + // + // The lambda runtime will ask for the next work item directly after a succeeded post + // response request. The desire for the next work item might be faster than the attempt + // to close the connection. This will lead to a situation where we try to the connection + // but the next request has already been scheduled on the connection that we want to + // close. For this reason we postpone succeeding the promise until the connection has + // been closed. This codepath will only be hit in the very, very unlikely event of the + // Lambda control plane demanding to close connection. (It's more or less only + // implemented to support http1.1 correctly.) This behavior is ensured with the test + // `LambdaTest.testNoKeepAliveServer`. + self.state = .closing + self.delegate.connectionWillClose(channel: context.channel) + context.close(promise: nil) + } + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.logger.trace( + "Channel error caught", + metadata: [ + "error": "\(error)" + ] + ) + // pending responses will fail with lastError in channelInactive since we are calling context.close + self.delegate.connectionErrorHappened(error, channel: context.channel) + + self.lastError = error + context.channel.close(promise: nil) + } + + func channelInactive(context: ChannelHandlerContext) { + // fail any pending responses with last error or assume peer disconnected + switch self.state { + case .connected(_, let lambdaState): + switch lambdaState { + case .waitingForNextInvocation(let continuation): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + case .sentResponse(let continuation): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + case .idle, .sendingResponse, .waitingForResponse: + break + } + self.state = .disconnected + default: + break + } + + // we don't need to forward channelInactive to the delegate, as the delegate observes the + // closeFuture + context.fireChannelInactive() + } +} +#endif diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift index 6238fac4..06fd0c4a 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#if swift(>=6.2) import Logging import NIOCore import NIOHTTP1 @@ -98,7 +99,7 @@ internal final class LambdaChannelHandler Invocation { + func nextInvocation() async throws -> Invocation { switch self.state { case .connected(let context, .idle): return try await withCheckedThrowingContinuation { @@ -120,7 +121,6 @@ internal final class LambdaChannelHandler