Skip to content

Commit 8acce15

Browse files
added dsm support for sns->lambda and sns->sqs->lambda
1 parent 6beb65d commit 8acce15

File tree

2 files changed

+142
-14
lines changed

2 files changed

+142
-14
lines changed

datadog_lambda/dsm.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,28 @@
33

44

55
def set_dsm_context(event, event_source):
6-
76
if event_source.equals(EventTypes.SQS):
87
_dsm_set_sqs_context(event)
8+
elif event_source.equals(EventTypes.SNS):
9+
_dsm_set_sns_context(event)
910

1011

11-
def _dsm_set_sqs_context(event):
12+
def _dsm_set_context_helper(
13+
event, service_type, arn_extractor, payload_size_calculator
14+
):
15+
"""
16+
Common helper function for setting DSM context.
17+
18+
Args:
19+
event: The Lambda event containing records
20+
service_type: The service type string (example: sqs', 'sns')
21+
arn_extractor: Function to extract the ARN from the record
22+
payload_size_calculator: Function to calculate payload size
23+
"""
1224
from datadog_lambda.wrapper import format_err_with_traceback
1325
from ddtrace.internal.datastreams import data_streams_processor
1426
from ddtrace.internal.datastreams.processor import DsmPathwayCodec
15-
from ddtrace.internal.datastreams.botocore import (
16-
get_datastreams_context,
17-
calculate_sqs_payload_size,
18-
)
27+
from ddtrace.internal.datastreams.botocore import get_datastreams_context
1928

2029
records = event.get("Records")
2130
if records is None:
@@ -24,15 +33,41 @@ def _dsm_set_sqs_context(event):
2433

2534
for record in records:
2635
try:
27-
queue_arn = record.get("eventSourceARN", "")
28-
29-
contextjson = get_datastreams_context(record)
30-
payload_size = calculate_sqs_payload_size(record)
36+
arn = arn_extractor(record)
37+
context_json = get_datastreams_context(record)
38+
payload_size = payload_size_calculator(record, context_json)
3139

32-
ctx = DsmPathwayCodec.decode(contextjson, processor)
40+
ctx = DsmPathwayCodec.decode(context_json, processor)
3341
ctx.set_checkpoint(
34-
["direction:in", f"topic:{queue_arn}", "type:sqs"],
42+
["direction:in", f"topic:{arn}", f"type:{service_type}"],
3543
payload_size=payload_size,
3644
)
3745
except Exception as e:
3846
logger.error(format_err_with_traceback(e))
47+
48+
49+
def _dsm_set_sns_context(event):
50+
from ddtrace.internal.datastreams.botocore import calculate_sns_payload_size
51+
52+
def sns_payload_calculator(record, context_json):
53+
return calculate_sns_payload_size(record, context_json)
54+
55+
def sns_arn_extractor(record):
56+
sns_data = record.get("Sns")
57+
if not sns_data:
58+
return ""
59+
return sns_data.get("TopicArn", "")
60+
61+
_dsm_set_context_helper(event, "sns", sns_arn_extractor, sns_payload_calculator)
62+
63+
64+
def _dsm_set_sqs_context(event):
65+
from ddtrace.internal.datastreams.botocore import calculate_sqs_payload_size
66+
67+
def sqs_payload_calculator(record, context_json):
68+
return calculate_sqs_payload_size(record)
69+
70+
def sqs_arn_extractor(record):
71+
return record.get("eventSourceARN", "")
72+
73+
_dsm_set_context_helper(event, "sqs", sqs_arn_extractor, sqs_payload_calculator)

tests/test_dsm.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
import unittest
22
from unittest.mock import patch, MagicMock
33

4-
from datadog_lambda.dsm import set_dsm_context, _dsm_set_sqs_context
4+
from datadog_lambda.dsm import (
5+
set_dsm_context,
6+
_dsm_set_sqs_context,
7+
_dsm_set_sns_context,
8+
)
59
from datadog_lambda.trigger import EventTypes, _EventSource
610

711

8-
class TestDsmSQSContext(unittest.TestCase):
12+
class TestDSMContext(unittest.TestCase):
913
def setUp(self):
1014
patcher = patch("datadog_lambda.dsm._dsm_set_sqs_context")
1115
self.mock_dsm_set_sqs_context = patcher.start()
1216
self.addCleanup(patcher.stop)
1317

18+
patcher = patch("datadog_lambda.dsm._dsm_set_sns_context")
19+
self.mock_dsm_set_sns_context = patcher.start()
20+
self.addCleanup(patcher.stop)
21+
1422
patcher = patch("ddtrace.internal.datastreams.data_streams_processor")
1523
self.mock_data_streams_processor = patcher.start()
1624
self.addCleanup(patcher.stop)
@@ -27,6 +35,13 @@ def setUp(self):
2735
self.mock_calculate_sqs_payload_size.return_value = 100
2836
self.addCleanup(patcher.stop)
2937

38+
patcher = patch(
39+
"ddtrace.internal.datastreams.botocore.calculate_sns_payload_size"
40+
)
41+
self.mock_calculate_sns_payload_size = patcher.start()
42+
self.mock_calculate_sns_payload_size.return_value = 150
43+
self.addCleanup(patcher.stop)
44+
3045
patcher = patch("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode")
3146
self.mock_dsm_pathway_codec_decode = patcher.start()
3247
self.addCleanup(patcher.stop)
@@ -110,3 +125,81 @@ def test_sqs_multiple_records_process_each_record(self):
110125
self.assertIn(f"topic:{expected_arns[i]}", tags)
111126
self.assertIn("type:sqs", tags)
112127
self.assertEqual(kwargs["payload_size"], 100)
128+
129+
def test_sns_event_with_no_records_does_nothing(self):
130+
"""Test that events where Records is None don't trigger DSM processing"""
131+
events_with_no_records = [
132+
{},
133+
{"Records": None},
134+
{"someOtherField": "value"},
135+
]
136+
137+
for event in events_with_no_records:
138+
_dsm_set_sns_context(event)
139+
self.mock_data_streams_processor.assert_not_called()
140+
141+
def test_sns_event_triggers_dsm_sns_context(self):
142+
"""Test that SNS event sources trigger the SNS-specific DSM context function"""
143+
sns_event = {
144+
"Records": [
145+
{
146+
"EventSource": "aws:sns",
147+
"Sns": {
148+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:my-topic",
149+
"Message": "Hello from SNS!",
150+
},
151+
}
152+
]
153+
}
154+
155+
event_source = _EventSource(EventTypes.SNS)
156+
set_dsm_context(sns_event, event_source)
157+
158+
self.mock_dsm_set_sns_context.assert_called_once_with(sns_event)
159+
160+
def test_sns_multiple_records_process_each_record(self):
161+
"""Test that each record in an SNS event gets processed individually"""
162+
multi_record_event = {
163+
"Records": [
164+
{
165+
"Sns": {
166+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic1",
167+
"Message": "Message 1",
168+
}
169+
},
170+
{
171+
"Sns": {
172+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic2",
173+
"Message": "Message 2",
174+
}
175+
},
176+
{
177+
"Sns": {
178+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:topic3",
179+
"Message": "Message 3",
180+
}
181+
},
182+
]
183+
}
184+
185+
mock_context = MagicMock()
186+
self.mock_dsm_pathway_codec_decode.return_value = mock_context
187+
188+
_dsm_set_sns_context(multi_record_event)
189+
190+
self.assertEqual(mock_context.set_checkpoint.call_count, 3)
191+
192+
calls = mock_context.set_checkpoint.call_args_list
193+
expected_arns = [
194+
"arn:aws:sns:us-east-1:123456789012:topic1",
195+
"arn:aws:sns:us-east-1:123456789012:topic2",
196+
"arn:aws:sns:us-east-1:123456789012:topic3",
197+
]
198+
199+
for i, call in enumerate(calls):
200+
args, kwargs = call
201+
tags = args[0]
202+
self.assertIn("direction:in", tags)
203+
self.assertIn(f"topic:{expected_arns[i]}", tags)
204+
self.assertIn("type:sns", tags)
205+
self.assertEqual(kwargs["payload_size"], 150)

0 commit comments

Comments
 (0)