11import unittest
22from unittest .mock import patch , MagicMock
33
4- from datadog_lambda .dsm import set_dsm_context , _dsm_set_sqs_context
4+ from datadog_lambda .dsm import (
5+ set_dsm_context ,
6+ _dsm_set_sqs_context ,
7+ _dsm_set_sns_context ,
8+ )
59from datadog_lambda .trigger import EventTypes , _EventSource
610
711
8- class TestDsmSQSContext (unittest .TestCase ):
12+ class TestDSMContext (unittest .TestCase ):
913 def setUp (self ):
1014 patcher = patch ("datadog_lambda.dsm._dsm_set_sqs_context" )
1115 self .mock_dsm_set_sqs_context = patcher .start ()
1216 self .addCleanup (patcher .stop )
1317
18+ patcher = patch ("datadog_lambda.dsm._dsm_set_sns_context" )
19+ self .mock_dsm_set_sns_context = patcher .start ()
20+ self .addCleanup (patcher .stop )
21+
1422 patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
1523 self .mock_data_streams_processor = patcher .start ()
1624 self .addCleanup (patcher .stop )
@@ -27,6 +35,13 @@ def setUp(self):
2735 self .mock_calculate_sqs_payload_size .return_value = 100
2836 self .addCleanup (patcher .stop )
2937
38+ patcher = patch (
39+ "ddtrace.internal.datastreams.botocore.calculate_sns_payload_size"
40+ )
41+ self .mock_calculate_sns_payload_size = patcher .start ()
42+ self .mock_calculate_sns_payload_size .return_value = 150
43+ self .addCleanup (patcher .stop )
44+
3045 patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
3146 self .mock_dsm_pathway_codec_decode = patcher .start ()
3247 self .addCleanup (patcher .stop )
@@ -110,3 +125,81 @@ def test_sqs_multiple_records_process_each_record(self):
110125 self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
111126 self .assertIn ("type:sqs" , tags )
112127 self .assertEqual (kwargs ["payload_size" ], 100 )
128+
129+ def test_sns_event_with_no_records_does_nothing (self ):
130+ """Test that events where Records is None don't trigger DSM processing"""
131+ events_with_no_records = [
132+ {},
133+ {"Records" : None },
134+ {"someOtherField" : "value" },
135+ ]
136+
137+ for event in events_with_no_records :
138+ _dsm_set_sns_context (event )
139+ self .mock_data_streams_processor .assert_not_called ()
140+
141+ def test_sns_event_triggers_dsm_sns_context (self ):
142+ """Test that SNS event sources trigger the SNS-specific DSM context function"""
143+ sns_event = {
144+ "Records" : [
145+ {
146+ "EventSource" : "aws:sns" ,
147+ "Sns" : {
148+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:my-topic" ,
149+ "Message" : "Hello from SNS!" ,
150+ },
151+ }
152+ ]
153+ }
154+
155+ event_source = _EventSource (EventTypes .SNS )
156+ set_dsm_context (sns_event , event_source )
157+
158+ self .mock_dsm_set_sns_context .assert_called_once_with (sns_event )
159+
160+ def test_sns_multiple_records_process_each_record (self ):
161+ """Test that each record in an SNS event gets processed individually"""
162+ multi_record_event = {
163+ "Records" : [
164+ {
165+ "Sns" : {
166+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic1" ,
167+ "Message" : "Message 1" ,
168+ }
169+ },
170+ {
171+ "Sns" : {
172+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic2" ,
173+ "Message" : "Message 2" ,
174+ }
175+ },
176+ {
177+ "Sns" : {
178+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic3" ,
179+ "Message" : "Message 3" ,
180+ }
181+ },
182+ ]
183+ }
184+
185+ mock_context = MagicMock ()
186+ self .mock_dsm_pathway_codec_decode .return_value = mock_context
187+
188+ _dsm_set_sns_context (multi_record_event )
189+
190+ self .assertEqual (mock_context .set_checkpoint .call_count , 3 )
191+
192+ calls = mock_context .set_checkpoint .call_args_list
193+ expected_arns = [
194+ "arn:aws:sns:us-east-1:123456789012:topic1" ,
195+ "arn:aws:sns:us-east-1:123456789012:topic2" ,
196+ "arn:aws:sns:us-east-1:123456789012:topic3" ,
197+ ]
198+
199+ for i , call in enumerate (calls ):
200+ args , kwargs = call
201+ tags = args [0 ]
202+ self .assertIn ("direction:in" , tags )
203+ self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
204+ self .assertIn ("type:sns" , tags )
205+ self .assertEqual (kwargs ["payload_size" ], 150 )
0 commit comments