11import unittest
2+ import base64
23import json
34from unittest .mock import patch , MagicMock
45
56from datadog_lambda .dsm import (
67 set_dsm_context ,
78 _dsm_set_sqs_context ,
9+ _dsm_set_sns_context ,
810 _get_dsm_context_from_lambda ,
911)
1012from datadog_lambda .trigger import EventTypes , _EventSource
@@ -16,6 +18,10 @@ def setUp(self):
1618 self .mock_dsm_set_sqs_context = patcher .start ()
1719 self .addCleanup (patcher .stop )
1820
21+ patcher = patch ("datadog_lambda.dsm._dsm_set_sns_context" )
22+ self .mock_dsm_set_sns_context = patcher .start ()
23+ self .addCleanup (patcher .stop )
24+
1925 patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
2026 self .mock_data_streams_processor = patcher .start ()
2127 self .addCleanup (patcher .stop )
@@ -32,6 +38,13 @@ def setUp(self):
3238 self .mock_calculate_sqs_payload_size .return_value = 100
3339 self .addCleanup (patcher .stop )
3440
41+ patcher = patch (
42+ "ddtrace.internal.datastreams.botocore.calculate_sns_payload_size"
43+ )
44+ self .mock_calculate_sns_payload_size = patcher .start ()
45+ self .mock_calculate_sns_payload_size .return_value = 150
46+ self .addCleanup (patcher .stop )
47+
3548 patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
3649 self .mock_dsm_pathway_codec_decode = patcher .start ()
3750 self .addCleanup (patcher .stop )
@@ -116,6 +129,84 @@ def test_sqs_multiple_records_process_each_record(self):
116129 self .assertIn ("type:sqs" , tags )
117130 self .assertEqual (kwargs ["payload_size" ], 100 )
118131
132+ def test_sns_event_with_no_records_does_nothing (self ):
133+ """Test that events where Records is None don't trigger DSM processing"""
134+ events_with_no_records = [
135+ {},
136+ {"Records" : None },
137+ {"someOtherField" : "value" },
138+ ]
139+
140+ for event in events_with_no_records :
141+ _dsm_set_sns_context (event )
142+ self .mock_data_streams_processor .assert_not_called ()
143+
144+ def test_sns_event_triggers_dsm_sns_context (self ):
145+ """Test that SNS event sources trigger the SNS-specific DSM context function"""
146+ sns_event = {
147+ "Records" : [
148+ {
149+ "EventSource" : "aws:sns" ,
150+ "Sns" : {
151+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:my-topic" ,
152+ "Message" : "Hello from SNS!" ,
153+ },
154+ }
155+ ]
156+ }
157+
158+ event_source = _EventSource (EventTypes .SNS )
159+ set_dsm_context (sns_event , event_source )
160+
161+ self .mock_dsm_set_sns_context .assert_called_once_with (sns_event )
162+
163+ def test_sns_multiple_records_process_each_record (self ):
164+ """Test that each record in an SNS event gets processed individually"""
165+ multi_record_event = {
166+ "Records" : [
167+ {
168+ "Sns" : {
169+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic1" ,
170+ "Message" : "Message 1" ,
171+ }
172+ },
173+ {
174+ "Sns" : {
175+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic2" ,
176+ "Message" : "Message 2" ,
177+ }
178+ },
179+ {
180+ "Sns" : {
181+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic3" ,
182+ "Message" : "Message 3" ,
183+ }
184+ },
185+ ]
186+ }
187+
188+ mock_context = MagicMock ()
189+ self .mock_dsm_pathway_codec_decode .return_value = mock_context
190+
191+ _dsm_set_sns_context (multi_record_event )
192+
193+ self .assertEqual (mock_context .set_checkpoint .call_count , 3 )
194+
195+ calls = mock_context .set_checkpoint .call_args_list
196+ expected_arns = [
197+ "arn:aws:sns:us-east-1:123456789012:topic1" ,
198+ "arn:aws:sns:us-east-1:123456789012:topic2" ,
199+ "arn:aws:sns:us-east-1:123456789012:topic3" ,
200+ ]
201+
202+ for i , call in enumerate (calls ):
203+ args , kwargs = call
204+ tags = args [0 ]
205+ self .assertIn ("direction:in" , tags )
206+ self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
207+ self .assertIn ("type:sns" , tags )
208+ self .assertEqual (kwargs ["payload_size" ], 150 )
209+
119210
120211class TestGetDSMContext (unittest .TestCase ):
121212 def test_sqs_to_lambda_string_value_format (self ):
@@ -164,6 +255,43 @@ def test_sqs_to_lambda_string_value_format(self):
164255 assert result ["x-datadog-parent-id" ] == "321987654"
165256 assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
166257
258+ def test_sns_to_lambda_format (self ):
259+ """Test format: message.Sns.MessageAttributes._datadog.Value.decode() (SNS -> lambda)"""
260+ trace_context = {
261+ "x-datadog-trace-id" : "111111111" ,
262+ "x-datadog-parent-id" : "222222222" ,
263+ "dd-pathway-ctx" : "test-pathway-ctx" ,
264+ }
265+ binary_data = base64 .b64encode (
266+ json .dumps (trace_context ).encode ("utf-8" )
267+ ).decode ("utf-8" )
268+
269+ sns_lambda_record = {
270+ "EventSource" : "aws:sns" ,
271+ "EventSubscriptionArn" : (
272+ "arn:aws:sns:us-east-1:123456789012:sns-topic:12345678-1234-1234-1234-123456789012"
273+ ),
274+ "Sns" : {
275+ "Type" : "Notification" ,
276+ "MessageId" : "95df01b4-ee98-5cb9-9903-4c221d41eb5e" ,
277+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:sns-topic" ,
278+ "Subject" : "Test Subject" ,
279+ "Message" : "Hello from SNS!" ,
280+ "Timestamp" : "2023-01-01T12:00:00.000Z" ,
281+ "MessageAttributes" : {
282+ "_datadog" : {"Type" : "Binary" , "Value" : binary_data }
283+ },
284+ },
285+ }
286+
287+ result = _get_dsm_context_from_lambda (sns_lambda_record )
288+
289+ assert result is not None
290+ assert result == trace_context
291+ assert result ["x-datadog-trace-id" ] == "111111111"
292+ assert result ["x-datadog-parent-id" ] == "222222222"
293+ assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
294+
167295 def test_no_message_attributes (self ):
168296 """Test message without MessageAttributes returns None."""
169297 message = {
0 commit comments