diff --git a/datadog_lambda/dsm.py b/datadog_lambda/dsm.py index 427f5e47..741be9ac 100644 --- a/datadog_lambda/dsm.py +++ b/datadog_lambda/dsm.py @@ -3,19 +3,28 @@ def set_dsm_context(event, event_source): - if event_source.equals(EventTypes.SQS): _dsm_set_sqs_context(event) + elif event_source.equals(EventTypes.SNS): + _dsm_set_sns_context(event) -def _dsm_set_sqs_context(event): +def _dsm_set_context_helper( + event, service_type, arn_extractor, payload_size_calculator +): + """ + Common helper function for setting DSM context. + + Args: + event: The Lambda event containing records + service_type: The service type string (example: sqs', 'sns') + arn_extractor: Function to extract the ARN from the record + payload_size_calculator: Function to calculate payload size + """ from datadog_lambda.wrapper import format_err_with_traceback from ddtrace.internal.datastreams import data_streams_processor from ddtrace.internal.datastreams.processor import DsmPathwayCodec - from ddtrace.internal.datastreams.botocore import ( - get_datastreams_context, - calculate_sqs_payload_size, - ) + from ddtrace.internal.datastreams.botocore import get_datastreams_context records = event.get("Records") if records is None: @@ -24,15 +33,41 @@ def _dsm_set_sqs_context(event): for record in records: try: - queue_arn = record.get("eventSourceARN", "") - - contextjson = get_datastreams_context(record) - payload_size = calculate_sqs_payload_size(record) + arn = arn_extractor(record) + context_json = get_datastreams_context(record) + payload_size = payload_size_calculator(record, context_json) - ctx = DsmPathwayCodec.decode(contextjson, processor) + ctx = DsmPathwayCodec.decode(context_json, processor) ctx.set_checkpoint( - ["direction:in", f"topic:{queue_arn}", "type:sqs"], + ["direction:in", f"topic:{arn}", f"type:{service_type}"], payload_size=payload_size, ) except Exception as e: logger.error(format_err_with_traceback(e)) + + +def _dsm_set_sns_context(event): + from ddtrace.internal.datastreams.botocore import calculate_sns_payload_size + + def sns_payload_calculator(record, context_json): + return calculate_sns_payload_size(record, context_json) + + def sns_arn_extractor(record): + sns_data = record.get("Sns") + if not sns_data: + return "" + return sns_data.get("TopicArn", "") + + _dsm_set_context_helper(event, "sns", sns_arn_extractor, sns_payload_calculator) + + +def _dsm_set_sqs_context(event): + from ddtrace.internal.datastreams.botocore import calculate_sqs_payload_size + + def sqs_payload_calculator(record, context_json): + return calculate_sqs_payload_size(record) + + def sqs_arn_extractor(record): + return record.get("eventSourceARN", "") + + _dsm_set_context_helper(event, "sqs", sqs_arn_extractor, sqs_payload_calculator) diff --git a/tests/test_dsm.py b/tests/test_dsm.py index 544212d8..30b82a96 100644 --- a/tests/test_dsm.py +++ b/tests/test_dsm.py @@ -1,16 +1,24 @@ import unittest from unittest.mock import patch, MagicMock -from datadog_lambda.dsm import set_dsm_context, _dsm_set_sqs_context +from datadog_lambda.dsm import ( + set_dsm_context, + _dsm_set_sqs_context, + _dsm_set_sns_context, +) from datadog_lambda.trigger import EventTypes, _EventSource -class TestDsmSQSContext(unittest.TestCase): +class TestDSMContext(unittest.TestCase): def setUp(self): patcher = patch("datadog_lambda.dsm._dsm_set_sqs_context") self.mock_dsm_set_sqs_context = patcher.start() self.addCleanup(patcher.stop) + patcher = patch("datadog_lambda.dsm._dsm_set_sns_context") + self.mock_dsm_set_sns_context = patcher.start() + self.addCleanup(patcher.stop) + patcher = patch("ddtrace.internal.datastreams.data_streams_processor") self.mock_data_streams_processor = patcher.start() self.addCleanup(patcher.stop) @@ -27,6 +35,13 @@ def setUp(self): self.mock_calculate_sqs_payload_size.return_value = 100 self.addCleanup(patcher.stop) + patcher = patch( + "ddtrace.internal.datastreams.botocore.calculate_sns_payload_size" + ) + self.mock_calculate_sns_payload_size = patcher.start() + self.mock_calculate_sns_payload_size.return_value = 150 + self.addCleanup(patcher.stop) + patcher = patch("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode") self.mock_dsm_pathway_codec_decode = patcher.start() self.addCleanup(patcher.stop) @@ -110,3 +125,81 @@ def test_sqs_multiple_records_process_each_record(self): self.assertIn(f"topic:{expected_arns[i]}", tags) self.assertIn("type:sqs", tags) self.assertEqual(kwargs["payload_size"], 100) + + def test_sns_event_with_no_records_does_nothing(self): + """Test that events where Records is None don't trigger DSM processing""" + events_with_no_records = [ + {}, + {"Records": None}, + {"someOtherField": "value"}, + ] + + for event in events_with_no_records: + _dsm_set_sns_context(event) + self.mock_data_streams_processor.assert_not_called() + + def test_sns_event_triggers_dsm_sns_context(self): + """Test that SNS event sources trigger the SNS-specific DSM context function""" + sns_event = { + "Records": [ + { + "EventSource": "aws:sns", + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:my-topic", + "Message": "Hello from SNS!", + }, + } + ] + } + + event_source = _EventSource(EventTypes.SNS) + set_dsm_context(sns_event, event_source) + + self.mock_dsm_set_sns_context.assert_called_once_with(sns_event) + + def test_sns_multiple_records_process_each_record(self): + """Test that each record in an SNS event gets processed individually""" + multi_record_event = { + "Records": [ + { + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic1", + "Message": "Message 1", + } + }, + { + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic2", + "Message": "Message 2", + } + }, + { + "Sns": { + "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic3", + "Message": "Message 3", + } + }, + ] + } + + mock_context = MagicMock() + self.mock_dsm_pathway_codec_decode.return_value = mock_context + + _dsm_set_sns_context(multi_record_event) + + self.assertEqual(mock_context.set_checkpoint.call_count, 3) + + calls = mock_context.set_checkpoint.call_args_list + expected_arns = [ + "arn:aws:sns:us-east-1:123456789012:topic1", + "arn:aws:sns:us-east-1:123456789012:topic2", + "arn:aws:sns:us-east-1:123456789012:topic3", + ] + + for i, call in enumerate(calls): + args, kwargs = call + tags = args[0] + self.assertIn("direction:in", tags) + self.assertIn(f"topic:{expected_arns[i]}", tags) + self.assertIn("type:sns", tags) + self.assertEqual(kwargs["payload_size"], 150)