66from datadog_lambda .dsm import (
77 set_dsm_context ,
88 _dsm_set_sqs_context ,
9+ _dsm_set_sns_context ,
910 _get_dsm_context_from_lambda ,
1011)
1112from datadog_lambda .trigger import EventTypes , _EventSource
@@ -17,6 +18,10 @@ def setUp(self):
1718 self .mock_dsm_set_sqs_context = patcher .start ()
1819 self .addCleanup (patcher .stop )
1920
21+ patcher = patch ("datadog_lambda.dsm._dsm_set_sns_context" )
22+ self .mock_dsm_set_sns_context = patcher .start ()
23+ self .addCleanup (patcher .stop )
24+
2025 patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
2126 self .mock_data_streams_processor = patcher .start ()
2227 self .addCleanup (patcher .stop )
@@ -33,6 +38,13 @@ def setUp(self):
3338 self .mock_calculate_sqs_payload_size .return_value = 100
3439 self .addCleanup (patcher .stop )
3540
41+ patcher = patch (
42+ "ddtrace.internal.datastreams.botocore.calculate_sns_payload_size"
43+ )
44+ self .mock_calculate_sns_payload_size = patcher .start ()
45+ self .mock_calculate_sns_payload_size .return_value = 150
46+ self .addCleanup (patcher .stop )
47+
3648 patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
3749 self .mock_dsm_pathway_codec_decode = patcher .start ()
3850 self .addCleanup (patcher .stop )
@@ -117,6 +129,84 @@ def test_sqs_multiple_records_process_each_record(self):
117129 self .assertIn ("type:sqs" , tags )
118130 self .assertEqual (kwargs ["payload_size" ], 100 )
119131
132+ def test_sns_event_with_no_records_does_nothing (self ):
133+ """Test that events where Records is None don't trigger DSM processing"""
134+ events_with_no_records = [
135+ {},
136+ {"Records" : None },
137+ {"someOtherField" : "value" },
138+ ]
139+
140+ for event in events_with_no_records :
141+ _dsm_set_sns_context (event )
142+ self .mock_data_streams_processor .assert_not_called ()
143+
144+ def test_sns_event_triggers_dsm_sns_context (self ):
145+ """Test that SNS event sources trigger the SNS-specific DSM context function"""
146+ sns_event = {
147+ "Records" : [
148+ {
149+ "EventSource" : "aws:sns" ,
150+ "Sns" : {
151+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:my-topic" ,
152+ "Message" : "Hello from SNS!" ,
153+ },
154+ }
155+ ]
156+ }
157+
158+ event_source = _EventSource (EventTypes .SNS )
159+ set_dsm_context (sns_event , event_source )
160+
161+ self .mock_dsm_set_sns_context .assert_called_once_with (sns_event )
162+
163+ def test_sns_multiple_records_process_each_record (self ):
164+ """Test that each record in an SNS event gets processed individually"""
165+ multi_record_event = {
166+ "Records" : [
167+ {
168+ "Sns" : {
169+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic1" ,
170+ "Message" : "Message 1" ,
171+ }
172+ },
173+ {
174+ "Sns" : {
175+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic2" ,
176+ "Message" : "Message 2" ,
177+ }
178+ },
179+ {
180+ "Sns" : {
181+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic3" ,
182+ "Message" : "Message 3" ,
183+ }
184+ },
185+ ]
186+ }
187+
188+ mock_context = MagicMock ()
189+ self .mock_dsm_pathway_codec_decode .return_value = mock_context
190+
191+ _dsm_set_sns_context (multi_record_event )
192+
193+ self .assertEqual (mock_context .set_checkpoint .call_count , 3 )
194+
195+ calls = mock_context .set_checkpoint .call_args_list
196+ expected_arns = [
197+ "arn:aws:sns:us-east-1:123456789012:topic1" ,
198+ "arn:aws:sns:us-east-1:123456789012:topic2" ,
199+ "arn:aws:sns:us-east-1:123456789012:topic3" ,
200+ ]
201+
202+ for i , call in enumerate (calls ):
203+ args , kwargs = call
204+ tags = args [0 ]
205+ self .assertIn ("direction:in" , tags )
206+ self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
207+ self .assertIn ("type:sns" , tags )
208+ self .assertEqual (kwargs ["payload_size" ], 150 )
209+
120210
121211class TestGetDSMContext (unittest .TestCase ):
122212 def test_sqs_to_lambda_string_value_format (self ):
@@ -203,7 +293,8 @@ def test_sns_to_lambda_format(self):
203293 assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
204294
205295 def test_sns_to_sqs_to_lambda_binary_value_format (self ):
206- """Test format: message.messageAttributes._datadog.binaryValue.decode() (SNS -> SQS -> lambda, raw)"""
296+ """Test format: message.messageAttributes._datadog.binaryValue.decode()
297+ (SNS -> SQS -> lambda, raw)"""
207298 trace_context = {
208299 "x-datadog-trace-id" : "777666555" ,
209300 "x-datadog-parent-id" : "444333222" ,
@@ -233,7 +324,8 @@ def test_sns_to_sqs_to_lambda_binary_value_format(self):
233324 assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
234325
235326 def test_sns_to_sqs_to_lambda_body_format (self ):
236- """Test format: message.body.MessageAttributes._datadog.Value.decode() (SNS -> SQS -> lambda)"""
327+ """Test format: message.body.MessageAttributes._datadog.Value.decode()
328+ (SNS -> SQS -> lambda)"""
237329 trace_context = {
238330 "x-datadog-trace-id" : "123987456" ,
239331 "x-datadog-parent-id" : "654321987" ,
0 commit comments