11import unittest
22import json
3- from unittest .mock import patch , MagicMock
3+ from unittest .mock import patch
44
55from datadog_lambda .dsm import (
66 set_dsm_context ,
@@ -16,24 +16,19 @@ def setUp(self):
1616 self .mock_dsm_set_sqs_context = patcher .start ()
1717 self .addCleanup (patcher .stop )
1818
19- patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
20- self .mock_data_streams_processor = patcher .start ()
21- self .addCleanup (patcher .stop )
22-
2319 patcher = patch ("ddtrace.internal.datastreams.botocore.get_datastreams_context" )
2420 self .mock_get_datastreams_context = patcher .start ()
2521 self .mock_get_datastreams_context .return_value = {}
2622 self .addCleanup (patcher .stop )
2723
28- patcher = patch (
29- "ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size"
30- )
31- self .mock_calculate_sqs_payload_size = patcher .start ()
32- self .mock_calculate_sqs_payload_size .return_value = 100
24+ # Patch set_consume_checkpoint for testing DSM functionality
25+ patcher = patch ("ddtrace.data_streams.set_consume_checkpoint" )
26+ self .mock_set_consume_checkpoint = patcher .start ()
3327 self .addCleanup (patcher .stop )
3428
35- patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
36- self .mock_dsm_pathway_codec_decode = patcher .start ()
29+ # Patch _get_dsm_context_from_lambda for testing DSM context extraction
30+ patcher = patch ("datadog_lambda.dsm._get_dsm_context_from_lambda" )
31+ self .mock_get_dsm_context_from_lambda = patcher .start ()
3732 self .addCleanup (patcher .stop )
3833
3934 def test_non_sqs_event_source_does_nothing (self ):
@@ -56,7 +51,8 @@ def test_sqs_event_with_no_records_does_nothing(self):
5651
5752 for event in events_with_no_records :
5853 _dsm_set_sqs_context (event )
59- self .mock_data_streams_processor .assert_not_called ()
54+ # Should not call set_consume_checkpoint for events without records
55+ self .mock_set_consume_checkpoint .assert_not_called ()
6056
6157 def test_sqs_event_triggers_dsm_sqs_context (self ):
6258 """Test that SQS event sources trigger the SQS-specific DSM context function"""
@@ -82,39 +78,66 @@ def test_sqs_multiple_records_process_each_record(self):
8278 {
8379 "eventSourceARN" : "arn:aws:sqs:us-east-1:123456789012:queue1" ,
8480 "body" : "Message 1" ,
81+ "messageAttributes" : {
82+ "_datadog" : {
83+ "stringValue" : json .dumps ({"dd-pathway-ctx-base64" : "context1" }),
84+ "dataType" : "String" ,
85+ }
86+ },
8587 },
8688 {
8789 "eventSourceARN" : "arn:aws:sqs:us-east-1:123456789012:queue2" ,
8890 "body" : "Message 2" ,
91+ "messageAttributes" : {
92+ "_datadog" : {
93+ "stringValue" : json .dumps ({"dd-pathway-ctx-base64" : "context2" }),
94+ "dataType" : "String" ,
95+ }
96+ },
8997 },
9098 {
9199 "eventSourceARN" : "arn:aws:sqs:us-east-1:123456789012:queue3" ,
92100 "body" : "Message 3" ,
101+ "messageAttributes" : {
102+ "_datadog" : {
103+ "stringValue" : json .dumps ({"dd-pathway-ctx-base64" : "context3" }),
104+ "dataType" : "String" ,
105+ }
106+ },
93107 },
94108 ]
95109 }
96110
97- mock_context = MagicMock ()
98- self .mock_dsm_pathway_codec_decode .return_value = mock_context
111+ self .mock_get_dsm_context_from_lambda .side_effect = [
112+ {"dd-pathway-ctx-base64" : "context1" },
113+ {"dd-pathway-ctx-base64" : "context2" },
114+ {"dd-pathway-ctx-base64" : "context3" },
115+ ]
99116
100117 _dsm_set_sqs_context (multi_record_event )
101118
102- self .assertEqual (mock_context . set_checkpoint .call_count , 3 )
119+ self .assertEqual (self . mock_set_consume_checkpoint .call_count , 3 )
103120
104- calls = mock_context . set_checkpoint .call_args_list
121+ calls = self . mock_set_consume_checkpoint .call_args_list
105122 expected_arns = [
106123 "arn:aws:sqs:us-east-1:123456789012:queue1" ,
107124 "arn:aws:sqs:us-east-1:123456789012:queue2" ,
108125 "arn:aws:sqs:us-east-1:123456789012:queue3" ,
109126 ]
127+ expected_contexts = ["context1" , "context2" , "context3" ]
110128
111129 for i , call in enumerate (calls ):
112130 args , kwargs = call
113- tags = args [0 ]
114- self .assertIn ("direction:in" , tags )
115- self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
116- self .assertIn ("type:sqs" , tags )
117- self .assertEqual (kwargs ["payload_size" ], 100 )
131+ service_type = args [0 ]
132+ arn = args [1 ]
133+ carrier_get_func = args [2 ]
134+
135+ self .assertEqual (service_type , "sqs" )
136+
137+ self .assertEqual (arn , expected_arns [i ])
138+
139+ pathway_ctx = carrier_get_func ("dd-pathway-ctx-base64" )
140+ self .assertEqual (pathway_ctx , expected_contexts [i ])
118141
119142
120143class TestGetDSMContext (unittest .TestCase ):
0 commit comments