Skip to content

Commit 1a5d1db

Browse files
authored
Fix a bug in s3 handler to handle SNS records (#758)
1 parent d48a961 commit 1a5d1db

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

aws/logs_monitoring/steps/handlers/s3_handler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ def s3_handler(event, context, metadata, cache_layer):
4444
# Get the S3 client
4545
s3 = get_s3_client()
4646
# if this is a S3 event carried in a SNS message, extract it and override the event
47-
first_record = event.get("Records")[0]
48-
if sns := first_record.get("Sns"):
49-
event = json.loads(sns.get("Message"))
47+
if "Sns" in event.get("Records")[0]:
48+
event = json.loads(event.get("Records")[0].get("Sns").get("Message"))
5049
# Get the object from the event and show its content type
51-
bucket = first_record.get("s3").get("bucket").get("name")
52-
key = urllib.parse.unquote_plus(first_record.get("s3").get("object").get("key"))
50+
bucket = event.get("Records")[0].get("s3").get("bucket").get("name")
51+
key = urllib.parse.unquote_plus(
52+
event.get("Records")[0].get("s3").get("object").get("key")
53+
)
5354
source = set_source(event, metadata, bucket, key)
5455
# Add Service tag
5556
add_service_tag(metadata)

aws/logs_monitoring/tests/test_s3_handler.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
import gzip
22
import unittest
3+
from unittest.mock import MagicMock, patch
34
from approvaltests.combination_approvals import verify_all_combinations
45
from steps.handlers.s3_handler import (
6+
s3_handler,
57
parse_service_arn,
68
get_partition_from_region,
79
get_structured_lines_for_s3_handler,
810
)
911

1012

1113
class TestS3EventsHandler(unittest.TestCase):
14+
class Context:
15+
function_version = 0
16+
invoked_function_arn = "invoked_function_arn"
17+
function_name = "function_name"
18+
memory_limit_in_mb = "10"
19+
1220
def parse_lines(self, data, key, source):
1321
bucket = "my-bucket"
1422
gzip_data = gzip.compress(bytes(data, "utf-8"))
@@ -59,6 +67,65 @@ def test_get_partition_from_region(self):
5967
self.assertEqual(get_partition_from_region("cn-north-1"), "aws-cn")
6068
self.assertEqual(get_partition_from_region(None), "aws")
6169

70+
@patch("steps.handlers.s3_handler.extract_data")
71+
@patch("steps.handlers.s3_handler.get_s3_client")
72+
def test_s3_handler(self, mock_s3_client, extract_data):
73+
event = {
74+
"Records": [
75+
{
76+
"s3": {
77+
"bucket": {"name": "my-bucket"},
78+
"object": {"key": "my-key"},
79+
}
80+
}
81+
]
82+
}
83+
context = self.Context()
84+
metadata = {"ddtags": ""}
85+
extract_data.side_effect = [("data".encode("utf-8"))]
86+
cache_layer = MagicMock()
87+
structured_lines = list(s3_handler(event, context, metadata, cache_layer))
88+
self.assertEqual(
89+
structured_lines,
90+
[
91+
{
92+
"aws": {"s3": {"bucket": "my-bucket", "key": "my-key"}},
93+
"message": "data",
94+
}
95+
],
96+
)
97+
self.assertEqual(metadata["ddsource"], "s3")
98+
self.assertEqual(metadata["host"], "arn:aws:s3:::my-bucket")
99+
100+
@patch("steps.handlers.s3_handler.extract_data")
101+
@patch("steps.handlers.s3_handler.get_s3_client")
102+
def test_s3_handler_with_sns(self, mock_s3_client, extract_data):
103+
event = {
104+
"Records": [
105+
{
106+
"Sns": {
107+
"Message": '{"Records": [{"s3": {"bucket": {"name": "my-bucket"}, "object": {"key": "sns-my-key"}}}]}'
108+
}
109+
}
110+
]
111+
}
112+
context = self.Context()
113+
metadata = {"ddtags": ""}
114+
extract_data.side_effect = [("data".encode("utf-8"))]
115+
cache_layer = MagicMock()
116+
structured_lines = list(s3_handler(event, context, metadata, cache_layer))
117+
self.assertEqual(
118+
structured_lines,
119+
[
120+
{
121+
"aws": {"s3": {"bucket": "my-bucket", "key": "sns-my-key"}},
122+
"message": "data",
123+
}
124+
],
125+
)
126+
self.assertEqual(metadata["ddsource"], "s3")
127+
self.assertEqual(metadata["host"], "arn:aws:s3:::my-bucket")
128+
62129

63130
class TestParseServiceArn(unittest.TestCase):
64131
def test_elb_s3_key_invalid(self):

0 commit comments

Comments
 (0)