Skip to content

Commit 5356cc6

Browse files
fix
1 parent 823a07f commit 5356cc6

File tree

2 files changed

+53
-25
lines changed

2 files changed

+53
-25
lines changed

datadog_lambda/dsm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ def _dsm_set_sqs_context(event):
2424
logger.debug("DataStreams skipped lambda message: %r", record)
2525
return None
2626

27-
def carrier_get(key):
28-
return context_json.get(key)
29-
27+
carrier_get = _create_carrier_get(context_json)
3028
set_consume_checkpoint("sqs", arn, carrier_get)
3129

3230

@@ -54,3 +52,10 @@ def _get_dsm_context_from_lambda(message):
5452
logger.debug("DataStreams did not handle lambda message: %r", message)
5553

5654
return context_json
55+
56+
57+
def _create_carrier_get(context_json):
58+
def carrier_get(key):
59+
return context_json.get(key)
60+
61+
return carrier_get

tests/test_dsm.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import json
3-
from unittest.mock import patch, MagicMock
3+
from unittest.mock import patch
44

55
from datadog_lambda.dsm import (
66
set_dsm_context,
@@ -16,24 +16,19 @@ def setUp(self):
1616
self.mock_dsm_set_sqs_context = patcher.start()
1717
self.addCleanup(patcher.stop)
1818

19-
patcher = patch("ddtrace.internal.datastreams.data_streams_processor")
20-
self.mock_data_streams_processor = patcher.start()
21-
self.addCleanup(patcher.stop)
22-
2319
patcher = patch("ddtrace.internal.datastreams.botocore.get_datastreams_context")
2420
self.mock_get_datastreams_context = patcher.start()
2521
self.mock_get_datastreams_context.return_value = {}
2622
self.addCleanup(patcher.stop)
2723

28-
patcher = patch(
29-
"ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size"
30-
)
31-
self.mock_calculate_sqs_payload_size = patcher.start()
32-
self.mock_calculate_sqs_payload_size.return_value = 100
24+
# Patch set_consume_checkpoint for testing DSM functionality
25+
patcher = patch("ddtrace.data_streams.set_consume_checkpoint")
26+
self.mock_set_consume_checkpoint = patcher.start()
3327
self.addCleanup(patcher.stop)
3428

35-
patcher = patch("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode")
36-
self.mock_dsm_pathway_codec_decode = patcher.start()
29+
# Patch _get_dsm_context_from_lambda for testing DSM context extraction
30+
patcher = patch("datadog_lambda.dsm._get_dsm_context_from_lambda")
31+
self.mock_get_dsm_context_from_lambda = patcher.start()
3732
self.addCleanup(patcher.stop)
3833

3934
def test_non_sqs_event_source_does_nothing(self):
@@ -56,7 +51,8 @@ def test_sqs_event_with_no_records_does_nothing(self):
5651

5752
for event in events_with_no_records:
5853
_dsm_set_sqs_context(event)
59-
self.mock_data_streams_processor.assert_not_called()
54+
# Should not call set_consume_checkpoint for events without records
55+
self.mock_set_consume_checkpoint.assert_not_called()
6056

6157
def test_sqs_event_triggers_dsm_sqs_context(self):
6258
"""Test that SQS event sources trigger the SQS-specific DSM context function"""
@@ -82,39 +78,66 @@ def test_sqs_multiple_records_process_each_record(self):
8278
{
8379
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue1",
8480
"body": "Message 1",
81+
"messageAttributes": {
82+
"_datadog": {
83+
"stringValue": json.dumps({"dd-pathway-ctx-base64": "context1"}),
84+
"dataType": "String",
85+
}
86+
},
8587
},
8688
{
8789
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue2",
8890
"body": "Message 2",
91+
"messageAttributes": {
92+
"_datadog": {
93+
"stringValue": json.dumps({"dd-pathway-ctx-base64": "context2"}),
94+
"dataType": "String",
95+
}
96+
},
8997
},
9098
{
9199
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue3",
92100
"body": "Message 3",
101+
"messageAttributes": {
102+
"_datadog": {
103+
"stringValue": json.dumps({"dd-pathway-ctx-base64": "context3"}),
104+
"dataType": "String",
105+
}
106+
},
93107
},
94108
]
95109
}
96110

97-
mock_context = MagicMock()
98-
self.mock_dsm_pathway_codec_decode.return_value = mock_context
111+
self.mock_get_dsm_context_from_lambda.side_effect = [
112+
{"dd-pathway-ctx-base64": "context1"},
113+
{"dd-pathway-ctx-base64": "context2"},
114+
{"dd-pathway-ctx-base64": "context3"},
115+
]
99116

100117
_dsm_set_sqs_context(multi_record_event)
101118

102-
self.assertEqual(mock_context.set_checkpoint.call_count, 3)
119+
self.assertEqual(self.mock_set_consume_checkpoint.call_count, 3)
103120

104-
calls = mock_context.set_checkpoint.call_args_list
121+
calls = self.mock_set_consume_checkpoint.call_args_list
105122
expected_arns = [
106123
"arn:aws:sqs:us-east-1:123456789012:queue1",
107124
"arn:aws:sqs:us-east-1:123456789012:queue2",
108125
"arn:aws:sqs:us-east-1:123456789012:queue3",
109126
]
127+
expected_contexts = ["context1", "context2", "context3"]
110128

111129
for i, call in enumerate(calls):
112130
args, kwargs = call
113-
tags = args[0]
114-
self.assertIn("direction:in", tags)
115-
self.assertIn(f"topic:{expected_arns[i]}", tags)
116-
self.assertIn("type:sqs", tags)
117-
self.assertEqual(kwargs["payload_size"], 100)
131+
service_type = args[0]
132+
arn = args[1]
133+
carrier_get_func = args[2]
134+
135+
self.assertEqual(service_type, "sqs")
136+
137+
self.assertEqual(arn, expected_arns[i])
138+
139+
pathway_ctx = carrier_get_func("dd-pathway-ctx-base64")
140+
self.assertEqual(pathway_ctx, expected_contexts[i])
118141

119142

120143
class TestGetDSMContext(unittest.TestCase):

0 commit comments

Comments
 (0)