|
| 1 | +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +using Amazon.Lambda.Model; |
| 5 | +using Amazon.Lambda.SQSEvents; |
| 6 | +using Amazon.Runtime; |
| 7 | +using Amazon.SQS.Model; |
| 8 | +using Amazon.SQS; |
| 9 | +using System.Text.Json; |
| 10 | +using Amazon.Lambda.TestTool.Services; |
| 11 | + |
| 12 | +namespace Amazon.Lambda.TestTool.Processes.SQSEventSource; |
| 13 | + |
| 14 | +/// <summary> |
| 15 | +/// IHostedService that will run continually polling the SQS queue for messages and invoking the connected |
| 16 | +/// Lambda function with the polled messages. |
| 17 | +/// </summary> |
| 18 | +public class SQSEventSourceBackgroundService : BackgroundService |
| 19 | +{ |
| 20 | + private static readonly List<string> DefaultAttributesToReceive = new List<string> { "All" }; |
| 21 | + private static readonly JsonSerializerOptions _jsonOptions = new JsonSerializerOptions |
| 22 | + { |
| 23 | + PropertyNamingPolicy = JsonNamingPolicy.CamelCase |
| 24 | + }; |
| 25 | + |
| 26 | + private readonly ILogger<SQSEventSourceProcess> _logger; |
| 27 | + private readonly IAmazonSQS _sqsClient; |
| 28 | + private readonly ILambdaClient _lambdaClient; |
| 29 | + private readonly SQSEventSourceBackgroundServiceConfig _config; |
| 30 | + |
| 31 | + /// <summary> |
| 32 | + /// Constructs instance of <see cref="SQSEventSourceBackgroundService"/>. |
| 33 | + /// </summary> |
| 34 | + /// <param name="logger">The logger</param> |
| 35 | + /// <param name="sqsClient">The SQS client used to poll messages from a queue.</param> |
| 36 | + /// <param name="config">The config of the service</param> |
| 37 | + /// <param name="lambdaClient">The Lambda client that can use a different endpoint for each invoke request.</param> |
| 38 | + public SQSEventSourceBackgroundService(ILogger<SQSEventSourceProcess> logger, IAmazonSQS sqsClient, SQSEventSourceBackgroundServiceConfig config, ILambdaClient lambdaClient) |
| 39 | + { |
| 40 | + _logger = logger; |
| 41 | + _sqsClient = sqsClient; |
| 42 | + _config = config; |
| 43 | + _lambdaClient = lambdaClient; |
| 44 | + } |
| 45 | + |
| 46 | + private async Task<string> GetQueueArn(CancellationToken stoppingToken) |
| 47 | + { |
| 48 | + var response = await _sqsClient.GetQueueAttributesAsync(new GetQueueAttributesRequest |
| 49 | + { |
| 50 | + QueueUrl = _config.QueueUrl, |
| 51 | + AttributeNames = new List<string> { "QueueArn" } |
| 52 | + }, stoppingToken); |
| 53 | + |
| 54 | + return response.QueueARN; |
| 55 | + } |
| 56 | + |
| 57 | + /// <summary> |
| 58 | + /// Execute the SQSEventSourceBackgroundService. |
| 59 | + /// </summary> |
| 60 | + /// <param name="stoppingToken">CancellationToken used to end the service.</param> |
| 61 | + /// <returns>Task for the background service.</returns> |
| 62 | + protected override async Task ExecuteAsync(CancellationToken stoppingToken) |
| 63 | + { |
| 64 | + // The queue arn is needed for creating the Lambda event. |
| 65 | + var queueArn = await GetQueueArn(stoppingToken); |
| 66 | + _logger.LogInformation("Starting polling for messages on SQS queue: {queueArn}", queueArn); |
| 67 | + while (!stoppingToken.IsCancellationRequested) |
| 68 | + { |
| 69 | + try |
| 70 | + { |
| 71 | + _logger.LogDebug("Polling {queueUrl} for messages", _config.QueueUrl); |
| 72 | + // Read a message from the queue using the ExternalCommands console application. |
| 73 | + var response = await _sqsClient.ReceiveMessageAsync(new ReceiveMessageRequest |
| 74 | + { |
| 75 | + QueueUrl = _config.QueueUrl, |
| 76 | + WaitTimeSeconds = 20, |
| 77 | + MessageAttributeNames = DefaultAttributesToReceive, |
| 78 | + MessageSystemAttributeNames = DefaultAttributesToReceive, |
| 79 | + MaxNumberOfMessages = _config.BatchSize, |
| 80 | + VisibilityTimeout = _config.VisibilityTimeout, |
| 81 | + }, stoppingToken); |
| 82 | + |
| 83 | + if (stoppingToken.IsCancellationRequested) |
| 84 | + { |
| 85 | + return; |
| 86 | + } |
| 87 | + if (response.Messages == null || response.Messages.Count == 0) |
| 88 | + { |
| 89 | + _logger.LogDebug("No messages received from while polling SQS"); |
| 90 | + // Since there are no messages, sleep a bit to wait for messages to come. |
| 91 | + await Task.Delay(200); |
| 92 | + continue; |
| 93 | + } |
| 94 | + |
| 95 | + var lambdaPayload = new |
| 96 | + { |
| 97 | + Records = ConvertToLambdaMessages(response.Messages, _sqsClient.Config.RegionEndpoint.SystemName, queueArn) |
| 98 | + }; |
| 99 | + |
| 100 | + var invokeRequest = new InvokeRequest |
| 101 | + { |
| 102 | + InvocationType = InvocationType.RequestResponse, |
| 103 | + FunctionName = _config.FunctionName, |
| 104 | + Payload = JsonSerializer.Serialize(lambdaPayload, _jsonOptions) |
| 105 | + }; |
| 106 | + |
| 107 | + _logger.LogInformation("Invoking Lambda function {functionName} function with {messageCount} messages", _config.FunctionName, lambdaPayload.Records.Count); |
| 108 | + var lambdaResponse = await _lambdaClient.InvokeAsync(invokeRequest, _config.LambdaRuntimeApi); |
| 109 | + |
| 110 | + if (lambdaResponse.FunctionError != null) |
| 111 | + { |
| 112 | + _logger.LogError("Invoking Lambda {function} function with {messageCount} failed with error {errorMessage}", _config.FunctionName, response.Messages.Count, lambdaResponse.FunctionError); |
| 113 | + continue; |
| 114 | + } |
| 115 | + |
| 116 | + if (!_config.DisableMessageDelete) |
| 117 | + { |
| 118 | + List<Message> messagesToDelete; |
| 119 | + if (lambdaResponse.Payload != null && lambdaResponse.Payload.Length > 0) |
| 120 | + { |
| 121 | + var partialResponse = JsonSerializer.Deserialize<SQSBatchResponse>(lambdaResponse.Payload); |
| 122 | + if (partialResponse == null) |
| 123 | + { |
| 124 | + lambdaResponse.Payload.Position = 0; |
| 125 | + using var reader = new StreamReader(lambdaResponse.Payload); |
| 126 | + var payloadString = reader.ReadToEnd(); |
| 127 | + _logger.LogError("Failed to deserialize response from Lambda function into SQSBatchResponse. Response payload:\n{payload}", payloadString); |
| 128 | + continue; |
| 129 | + } |
| 130 | + |
| 131 | + if (partialResponse.BatchItemFailures == null || partialResponse.BatchItemFailures.Count == 0) |
| 132 | + { |
| 133 | + _logger.LogDebug("Partial SQS response received with no failures"); |
| 134 | + messagesToDelete = response.Messages; |
| 135 | + } |
| 136 | + else |
| 137 | + { |
| 138 | + _logger.LogDebug("Partial SQS response received with {count} failures", partialResponse.BatchItemFailures.Count); |
| 139 | + messagesToDelete = new List<Message>(); |
| 140 | + foreach (var message in response.Messages) |
| 141 | + { |
| 142 | + if (!partialResponse.BatchItemFailures.Any(x => string.Equals(x.ItemIdentifier, message.MessageId))) |
| 143 | + { |
| 144 | + messagesToDelete.Add(message); |
| 145 | + } |
| 146 | + } |
| 147 | + } |
| 148 | + } |
| 149 | + else |
| 150 | + { |
| 151 | + _logger.LogDebug("No partial response received. All messages eligible for deletion"); |
| 152 | + messagesToDelete = response.Messages; |
| 153 | + } |
| 154 | + |
| 155 | + if (messagesToDelete.Count > 0) |
| 156 | + { |
| 157 | + var deleteRequest = new DeleteMessageBatchRequest |
| 158 | + { |
| 159 | + QueueUrl = _config.QueueUrl, |
| 160 | + Entries = messagesToDelete.Select(m => new DeleteMessageBatchRequestEntry { Id = m.MessageId, ReceiptHandle = m.ReceiptHandle }).ToList() |
| 161 | + }; |
| 162 | + |
| 163 | + _logger.LogDebug("Deleting {messageCount} messages from queue", deleteRequest.Entries.Count); |
| 164 | + await _sqsClient.DeleteMessageBatchAsync(deleteRequest, stoppingToken); |
| 165 | + } |
| 166 | + } |
| 167 | + } |
| 168 | + catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) |
| 169 | + { |
| 170 | + return; |
| 171 | + } |
| 172 | + catch (TaskCanceledException) when (stoppingToken.IsCancellationRequested) |
| 173 | + { |
| 174 | + return; |
| 175 | + } |
| 176 | + catch (Exception e) |
| 177 | + { |
| 178 | + _logger.LogWarning(e, "Exception occurred in SQS poller for {queueUrl}: {message}", _config.QueueUrl, e.Message); |
| 179 | + |
| 180 | + // Add a delay before restarting loop in case the exception was a transient error that needs a little time to reset. |
| 181 | + await Task.Delay(3000); |
| 182 | + } |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + /// <summary> |
| 187 | + /// Convert from the SDK's list of messages to the Lambda event's SQS message type. |
| 188 | + /// </summary> |
| 189 | + /// <param name="messages">List of messages using the SDK's .NET type</param> |
| 190 | + /// <param name="awsRegion">The aws region the messages came from.</param> |
| 191 | + /// <param name="queueArn">The SQS queue arn the messages came from.</param> |
| 192 | + /// <returns>List of messages using the Lambda event's .NET type.</returns> |
| 193 | + internal static List<SQSEvent.SQSMessage> ConvertToLambdaMessages(List<Message> messages, string awsRegion, string queueArn) |
| 194 | + { |
| 195 | + return messages.Select(m => ConvertToLambdaMessage(m, awsRegion, queueArn)).ToList(); |
| 196 | + } |
| 197 | + |
| 198 | + /// <summary> |
| 199 | + /// Convert from the SDK's SQS message to the Lambda event's SQS message type. |
| 200 | + /// </summary> |
| 201 | + /// <param name="message">Message using the SDK's .NET type</param> |
| 202 | + /// <param name="awsRegion">The aws region the message came from.</param> |
| 203 | + /// <param name="queueArn">The SQS queue arn the message came from.</param> |
| 204 | + /// <returns>Messages using the Lambda event's .NET type.</returns> |
| 205 | + internal static SQSEvent.SQSMessage ConvertToLambdaMessage(Message message, string awsRegion, string queueArn) |
| 206 | + { |
| 207 | + var lambdaMessage = new SQSEvent.SQSMessage |
| 208 | + { |
| 209 | + AwsRegion = awsRegion, |
| 210 | + Body = message.Body, |
| 211 | + EventSource = "aws:sqs", |
| 212 | + EventSourceArn = queueArn, |
| 213 | + Md5OfBody = message.MD5OfBody, |
| 214 | + Md5OfMessageAttributes = message.MD5OfMessageAttributes, |
| 215 | + MessageId = message.MessageId, |
| 216 | + ReceiptHandle = message.ReceiptHandle, |
| 217 | + }; |
| 218 | + |
| 219 | + if (message.MessageAttributes != null && message.MessageAttributes.Count > 0) |
| 220 | + { |
| 221 | + lambdaMessage.MessageAttributes = new Dictionary<string, SQSEvent.MessageAttribute>(); |
| 222 | + foreach (var kvp in message.MessageAttributes) |
| 223 | + { |
| 224 | + var lambdaAttribute = new SQSEvent.MessageAttribute |
| 225 | + { |
| 226 | + DataType = kvp.Value.DataType, |
| 227 | + StringValue = kvp.Value.StringValue, |
| 228 | + BinaryValue = kvp.Value.BinaryValue |
| 229 | + }; |
| 230 | + |
| 231 | + lambdaMessage.MessageAttributes.Add(kvp.Key, lambdaAttribute); |
| 232 | + } |
| 233 | + } |
| 234 | + |
| 235 | + if (message.Attributes != null && message.Attributes.Count > 0) |
| 236 | + { |
| 237 | + lambdaMessage.Attributes = message.Attributes; |
| 238 | + } |
| 239 | + |
| 240 | + return lambdaMessage; |
| 241 | + } |
| 242 | +} |
0 commit comments