Skip to content

Commit 05d7a18

Browse files
add sns -> lambda support
1 parent 45ed35f commit 05d7a18

File tree

4 files changed

+202
-4
lines changed

4 files changed

+202
-4
lines changed

datadog_lambda/dsm.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
import json
3+
import base64
4+
35
from datadog_lambda.trigger import EventTypes
46

57
logger = logging.getLogger(__name__)
@@ -8,6 +10,8 @@
810
def set_dsm_context(event, event_source):
911
if event_source.equals(EventTypes.SQS):
1012
_dsm_set_sqs_context(event)
13+
elif event_source.equals(EventTypes.SNS):
14+
_dsm_set_sns_context(event)
1115

1216

1317
def _dsm_set_sqs_context(event):
@@ -28,13 +32,43 @@ def _dsm_set_sqs_context(event):
2832
set_consume_checkpoint("sqs", arn, carrier_get)
2933

3034

35+
def _dsm_set_sns_context(event):
36+
from ddtrace.data_streams import set_consume_checkpoint
37+
38+
records = event.get("Records")
39+
if records is None:
40+
return
41+
42+
for record in records:
43+
sns_data = record.get("Sns")
44+
if not sns_data:
45+
return
46+
arn = sns_data.get("TopicArn", "")
47+
context_json = _get_dsm_context_from_lambda(sns_data)
48+
if not context_json:
49+
logger.debug("DataStreams skipped lambda message: %r", sns_data)
50+
return None
51+
52+
carrier_get = _create_carrier_get(context_json)
53+
set_consume_checkpoint("sns", arn, carrier_get)
54+
55+
3156
def _get_dsm_context_from_lambda(message):
3257
"""
3358
Lambda-specific message formats:
3459
- message.messageAttributes._datadog.stringValue (SQS -> lambda)
60+
- message.Sns.MessageAttributes._datadog.Value.decode() (SNS -> lambda)
3561
"""
3662
context_json = None
37-
message_attributes = message.get("messageAttributes")
63+
message_body = message
64+
65+
if "Sns" in message:
66+
message_body = message["Sns"]
67+
68+
message_attributes = message_body.get("MessageAttributes") or message_body.get(
69+
"messageAttributes"
70+
)
71+
3872
if not message_attributes:
3973
logger.debug("DataStreams skipped lambda message: %r", message)
4074
return None
@@ -45,7 +79,11 @@ def _get_dsm_context_from_lambda(message):
4579

4680
datadog_attr = message_attributes["_datadog"]
4781

48-
if "stringValue" in datadog_attr:
82+
if message_body.get("Type") == "Notification":
83+
# SNS -> lambda notification
84+
if datadog_attr.get("Type") == "Binary":
85+
context_json = json.loads(base64.b64decode(datadog_attr["Value"]).decode())
86+
elif "stringValue" in datadog_attr:
4987
# SQS -> lambda
5088
context_json = json.loads(datadog_attr["stringValue"])
5189
else:
4.34 MB
Binary file not shown.
4.31 MB
Binary file not shown.

tests/test_dsm.py

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import unittest
2+
import base64
23
import json
34
from unittest.mock import patch
45

56
from datadog_lambda.dsm import (
67
set_dsm_context,
78
_dsm_set_sqs_context,
9+
_dsm_set_sns_context,
810
_get_dsm_context_from_lambda,
911
)
1012
from datadog_lambda.trigger import EventTypes, _EventSource
@@ -16,14 +18,18 @@ def setUp(self):
1618
self.mock_dsm_set_sqs_context = patcher.start()
1719
self.addCleanup(patcher.stop)
1820

19-
# Patch set_consume_checkpoint for testing DSM functionality
2021
patcher = patch("ddtrace.data_streams.set_consume_checkpoint")
2122
self.mock_set_consume_checkpoint = patcher.start()
2223
self.addCleanup(patcher.stop)
2324

24-
# Patch _get_dsm_context_from_lambda for testing DSM context extraction
2525
patcher = patch("datadog_lambda.dsm._get_dsm_context_from_lambda")
2626
self.mock_get_dsm_context_from_lambda = patcher.start()
27+
patcher = patch("datadog_lambda.dsm._dsm_set_sns_context")
28+
self.mock_dsm_set_sns_context = patcher.start()
29+
self.addCleanup(patcher.stop)
30+
31+
patcher = patch("ddtrace.internal.datastreams.data_streams_processor")
32+
self.mock_data_streams_processor = patcher.start()
2733
self.addCleanup(patcher.stop)
2834

2935
def test_non_sqs_event_source_does_nothing(self):
@@ -134,6 +140,123 @@ def test_sqs_multiple_records_process_each_record(self):
134140
pathway_ctx = carrier_get_func("dd-pathway-ctx-base64")
135141
self.assertEqual(pathway_ctx, expected_contexts[i])
136142

143+
def test_sns_event_with_no_records_does_nothing(self):
144+
"""Test that events where Records is None don't trigger DSM processing"""
145+
events_with_no_records = [
146+
{},
147+
{"Records": None},
148+
{"someOtherField": "value"},
149+
]
150+
151+
for event in events_with_no_records:
152+
_dsm_set_sns_context(event)
153+
self.mock_set_consume_checkpoint.assert_not_called()
154+
155+
def test_sns_event_triggers_dsm_sns_context(self):
156+
"""Test that SNS event sources trigger the SNS-specific DSM context function"""
157+
sns_event = {
158+
"Records": [
159+
{
160+
"EventSource": "aws:sns",
161+
"Sns": {
162+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:my-topic",
163+
"Message": "Hello from SNS!",
164+
},
165+
}
166+
]
167+
}
168+
169+
event_source = _EventSource(EventTypes.SNS)
170+
set_dsm_context(sns_event, event_source)
171+
172+
self.mock_dsm_set_sns_context.assert_called_once_with(sns_event)
173+
174+
def test_sns_multiple_records_process_each_record(self):
175+
"""Test that each record in an SNS event gets processed individually"""
176+
multi_record_event = {
177+
"Records": [
178+
{
179+
"EventSource": "aws:sns",
180+
"Sns": {
181+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic1",
182+
"Message": "Message 1",
183+
"MessageAttributes": {
184+
"_datadog": {
185+
"Type": "Binary",
186+
"Value": base64.b64encode(
187+
json.dumps({"dd-pathway-ctx-base64": "context1"})
188+
.encode("utf-8")
189+
).decode("utf-8")
190+
}
191+
},
192+
}
193+
},
194+
{
195+
"EventSource": "aws:sns",
196+
"Sns": {
197+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic2",
198+
"Message": "Message 2",
199+
"MessageAttributes": {
200+
"_datadog": {
201+
"Type": "Binary",
202+
"Value": base64.b64encode(
203+
json.dumps({"dd-pathway-ctx-base64": "context2"})
204+
.encode("utf-8")
205+
).decode("utf-8")
206+
}
207+
},
208+
}
209+
},
210+
{
211+
"EventSource": "aws:sns",
212+
"Sns": {
213+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic3",
214+
"Message": "Message 3",
215+
"MessageAttributes": {
216+
"_datadog": {
217+
"Type": "Binary",
218+
"Value": base64.b64encode(
219+
json.dumps({"dd-pathway-ctx-base64": "context3"})
220+
.encode("utf-8")
221+
).decode("utf-8")
222+
}
223+
},
224+
}
225+
},
226+
]
227+
}
228+
229+
self.mock_get_dsm_context_from_lambda.side_effect = [
230+
{"dd-pathway-ctx-base64": "context1"},
231+
{"dd-pathway-ctx-base64": "context2"},
232+
{"dd-pathway-ctx-base64": "context3"},
233+
]
234+
235+
_dsm_set_sns_context(multi_record_event)
236+
237+
self.assertEqual(self.mock_set_consume_checkpoint.call_count, 3)
238+
239+
calls = self.mock_set_consume_checkpoint.call_args_list
240+
expected_arns = [
241+
"arn:aws:sns:us-east-1:123456789012:topic1",
242+
"arn:aws:sns:us-east-1:123456789012:topic2",
243+
"arn:aws:sns:us-east-1:123456789012:topic3",
244+
]
245+
expected_contexts = ["context1", "context2", "context3"]
246+
247+
for i, call in enumerate(calls):
248+
args, kwargs = call
249+
service_type = args[0]
250+
arn = args[1]
251+
carrier_get_func = args[2]
252+
253+
self.assertEqual(service_type, "sns")
254+
255+
self.assertEqual(arn, expected_arns[i])
256+
257+
pathway_ctx = carrier_get_func("dd-pathway-ctx-base64")
258+
self.assertEqual(pathway_ctx, expected_contexts[i])
259+
137260

138261
class TestGetDSMContext(unittest.TestCase):
139262
def test_sqs_to_lambda_string_value_format(self):
@@ -182,6 +305,43 @@ def test_sqs_to_lambda_string_value_format(self):
182305
assert result["x-datadog-parent-id"] == "321987654"
183306
assert result["dd-pathway-ctx"] == "test-pathway-ctx"
184307

308+
def test_sns_to_lambda_format(self):
309+
"""Test format: message.Sns.MessageAttributes._datadog.Value.decode() (SNS -> lambda)"""
310+
trace_context = {
311+
"x-datadog-trace-id": "111111111",
312+
"x-datadog-parent-id": "222222222",
313+
"dd-pathway-ctx": "test-pathway-ctx",
314+
}
315+
binary_data = base64.b64encode(
316+
json.dumps(trace_context).encode("utf-8")
317+
).decode("utf-8")
318+
319+
sns_lambda_record = {
320+
"EventSource": "aws:sns",
321+
"EventSubscriptionArn": (
322+
"arn:aws:sns:us-east-1:123456789012:sns-topic:12345678-1234-1234-1234-123456789012"
323+
),
324+
"Sns": {
325+
"Type": "Notification",
326+
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
327+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:sns-topic",
328+
"Subject": "Test Subject",
329+
"Message": "Hello from SNS!",
330+
"Timestamp": "2023-01-01T12:00:00.000Z",
331+
"MessageAttributes": {
332+
"_datadog": {"Type": "Binary", "Value": binary_data}
333+
},
334+
},
335+
}
336+
337+
result = _get_dsm_context_from_lambda(sns_lambda_record)
338+
339+
assert result is not None
340+
assert result == trace_context
341+
assert result["x-datadog-trace-id"] == "111111111"
342+
assert result["x-datadog-parent-id"] == "222222222"
343+
assert result["dd-pathway-ctx"] == "test-pathway-ctx"
344+
185345
def test_no_message_attributes(self):
186346
"""Test message without MessageAttributes returns None."""
187347
message = {

0 commit comments

Comments
 (0)