Skip to content

Commit b171bda

Browse files
authored
[APPSEC]: blocking for alb multi value headers events (#655)
* fix: alb multi val headers also has multi val query params * appsec: fix blocking for alb multivalue headers events
1 parent 9c93e47 commit b171bda

File tree

5 files changed

+79
-22
lines changed

5 files changed

+79
-22
lines changed

datadog_lambda/asm.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,16 @@ def asm_start_request(
7373
route: Optional[str] = None
7474

7575
if event_source.event_type == EventTypes.ALB:
76-
headers = event.get("headers")
77-
multi_value_request_headers = event.get("multiValueHeaders")
78-
if multi_value_request_headers:
79-
request_headers = _to_single_value_headers(multi_value_request_headers)
80-
else:
81-
request_headers = headers or {}
82-
8376
raw_uri = event.get("path")
84-
parsed_query = event.get("multiValueQueryStringParameters") or event.get(
85-
"queryStringParameters"
86-
)
77+
78+
if event_source.subtype == EventSubtypes.ALB:
79+
request_headers = event.get("headers", {})
80+
parsed_query = event.get("queryStringParameters")
81+
if event_source.subtype == EventSubtypes.ALB_MULTI_VALUE_HEADERS:
82+
request_headers = _to_single_value_headers(
83+
event.get("multiValueHeaders", {})
84+
)
85+
parsed_query = event.get("multiValueQueryStringParameters")
8786

8887
elif event_source.event_type == EventTypes.LAMBDA_FUNCTION_URL:
8988
request_headers = event.get("headers", {})
@@ -226,15 +225,27 @@ def get_asm_blocked_response(
226225
content_type = blocked.get("content-type", "application/json")
227226
content = http_utils._get_blocked_template(content_type)
228227

229-
response_headers = {
230-
"content-type": content_type,
231-
}
232-
if "location" in blocked:
233-
response_headers["location"] = blocked["location"]
234-
235-
return {
228+
response = {
236229
"statusCode": blocked.get("status_code", 403),
237-
"headers": response_headers,
238230
"body": content,
239231
"isBase64Encoded": False,
240232
}
233+
234+
needs_multi_value_headers = event_source.equals(
235+
EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS
236+
)
237+
238+
if needs_multi_value_headers:
239+
response["multiValueHeaders"] = {
240+
"content-type": [content_type],
241+
}
242+
if "location" in blocked:
243+
response["multiValueHeaders"]["location"] = [blocked["location"]]
244+
else:
245+
response["headers"] = {
246+
"content-type": content_type,
247+
}
248+
if "location" in blocked:
249+
response["headers"]["location"] = blocked["location"]
250+
251+
return response

datadog_lambda/trigger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class EventSubtypes(_stringTypedEnum):
5454
WEBSOCKET = "websocket"
5555
HTTP_API = "http-api"
5656

57+
ALB = "alb" # regular alb
58+
# ALB with the multi-value headers option checked on the target group
59+
ALB_MULTI_VALUE_HEADERS = "alb-multi-value-headers"
60+
5761

5862
class _EventSource:
5963
"""
@@ -133,7 +137,12 @@ def parse_event_source(event: dict) -> _EventSource:
133137
event_source.subtype = EventSubtypes.WEBSOCKET
134138

135139
if request_context and request_context.get("elb"):
136-
event_source = _EventSource(EventTypes.ALB)
140+
if "multiValueHeaders" in event:
141+
event_source = _EventSource(
142+
EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS
143+
)
144+
else:
145+
event_source = _EventSource(EventTypes.ALB, EventSubtypes.ALB)
137146

138147
if event.get("awslogs"):
139148
event_source = _EventSource(EventTypes.CLOUDWATCH_LOGS)

tests/event_samples/application-load-balancer-mutivalue-headers.json renamed to tests/event_samples/application-load-balancer-multivalue-headers.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
},
77
"httpMethod": "GET",
88
"path": "/lambda",
9-
"queryStringParameters": {
9+
"multiValueQueryStringParameters": {
1010
"query": "1234ABCD"
1111
},
1212
"multiValueHeaders": {

tests/test_asm.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
get_asm_blocked_response,
99
)
1010
from datadog_lambda.trigger import (
11+
EventSubtypes,
1112
EventTypes,
1213
_EventSource,
1314
extract_trigger_tags,
@@ -34,7 +35,7 @@
3435
),
3536
(
3637
"application_load_balancer_multivalue_headers",
37-
"application-load-balancer-mutivalue-headers.json",
38+
"application-load-balancer-multivalue-headers.json",
3839
"72.12.164.125",
3940
"/lambda?query=1234ABCD",
4041
"GET",
@@ -111,7 +112,7 @@
111112
),
112113
(
113114
"application_load_balancer_multivalue_headers",
114-
"application-load-balancer-mutivalue-headers.json",
115+
"application-load-balancer-multivalue-headers.json",
115116
{
116117
"statusCode": 404,
117118
"multiValueHeaders": {
@@ -397,6 +398,25 @@ def test_get_asm_blocked_response_blocked(
397398
response = get_asm_blocked_response(event_source)
398399
assert response["statusCode"] == expected_status
399400
assert response["headers"] == expected_headers
401+
assert "multiValueHeaders" not in response
402+
403+
404+
@patch("datadog_lambda.asm.get_blocked")
405+
def test_get_asm_blocked_response_blocked_multi_value_headers(
406+
mock_get_blocked,
407+
):
408+
# HTML blocking response
409+
mock_get_blocked.return_value = {
410+
"status_code": 401,
411+
"type": "html",
412+
"content-type": "text/html",
413+
}
414+
415+
event_source = _EventSource(EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS)
416+
response = get_asm_blocked_response(event_source)
417+
assert response["statusCode"] == 401
418+
assert response["multiValueHeaders"] == {"content-type": ["text/html"]}
419+
assert "headers" not in response
400420

401421

402422
@patch("datadog_lambda.asm.get_blocked")

tests/test_trigger.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from datadog_lambda.trigger import (
77
EventSubtypes,
8+
EventTypes,
89
parse_event_source,
910
get_event_source_arn,
1011
extract_trigger_tags,
@@ -117,6 +118,22 @@ def test_event_source_application_load_balancer(self):
117118
event_source = parse_event_source(event)
118119
event_source_arn = get_event_source_arn(event_source, event, ctx)
119120
self.assertEqual(event_source.to_string(), event_sample_source)
121+
self.assertEqual(event_source.subtype, EventSubtypes.ALB)
122+
self.assertEqual(
123+
event_source_arn,
124+
"arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-xyz/123abc",
125+
)
126+
127+
def test_event_source_application_load_balancer_multi_value_headers(self):
128+
event_sample_source = "application-load-balancer-multivalue-headers"
129+
test_file = event_samples + event_sample_source + ".json"
130+
with open(test_file, "r") as event:
131+
event = json.load(event)
132+
ctx = get_mock_context()
133+
event_source = parse_event_source(event)
134+
event_source_arn = get_event_source_arn(event_source, event, ctx)
135+
self.assertEqual(event_source.event_type, EventTypes.ALB)
136+
self.assertEqual(event_source.subtype, EventSubtypes.ALB_MULTI_VALUE_HEADERS)
120137
self.assertEqual(
121138
event_source_arn,
122139
"arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-xyz/123abc",

0 commit comments

Comments
 (0)