Skip to content

Commit d9ea835

Browse files
authored
Fix ALB multi value headers (#189)
* copy with some slight changes from #179 * move headers conversion from `List[Tuple[bytes, bytes]]` to `List[List[bytes]]` up to where the headers are initially transformed * no longer assert that first case mutated string is lowercase * only pull headers from one of `multiValueHeaders` and `headers`
1 parent 0390502 commit d9ea835

File tree

2 files changed

+200
-182
lines changed

2 files changed

+200
-182
lines changed

mangum/handlers/aws_alb.py

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import urllib.parse
33
from typing import Any, Dict, Generator, List, Tuple
4+
from itertools import islice
45

56
from .abstract_handler import AbstractHandler
67
from .. import Response, Request
@@ -25,12 +26,25 @@ def all_casings(input_string: str) -> Generator:
2526
yield first.upper() + sub_casing
2627

2728

29+
def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
30+
"""Create str/str key/value headers, with duplicate keys case mutated."""
31+
headers = {}
32+
for key, values in multi_value_headers.items():
33+
if len(values) > 0:
34+
casings = list(islice(all_casings(key), len(values)))
35+
for value, cased_key in zip(values, casings):
36+
headers[cased_key] = value
37+
return headers
38+
39+
2840
class AwsAlb(AbstractHandler):
2941
"""
3042
Handles AWS Elastic Load Balancer, really Application Load Balancer events
3143
transforming them into ASGI Scope and handling responses
3244
33-
See: https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
45+
See:
46+
1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
47+
2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
3448
"""
3549

3650
TYPE = "AWS_ALB"
@@ -71,22 +85,40 @@ def encode_query_string(self) -> bytes:
7185

7286
return urllib.parse.urlencode(query).encode()
7387

88+
def transform_headers(self) -> List[Tuple[bytes, bytes]]:
89+
"""Convert headers to a list of two-tuples per ASGI spec.
90+
91+
Only one of `multiValueHeaders` or `headers` should be defined in the
92+
trigger event. However, we act as though they both might exist and pull
93+
headers out of both.
94+
"""
95+
headers = []
96+
if "multiValueHeaders" in self.trigger_event:
97+
for k, v in self.trigger_event["multiValueHeaders"].items():
98+
for inner_v in v:
99+
headers.append((k.lower().encode(), inner_v.encode()))
100+
else:
101+
for k, v in self.trigger_event["headers"].items():
102+
headers.append((k.lower().encode(), v.encode()))
103+
return headers
104+
74105
@property
75106
def request(self) -> Request:
76107
event = self.trigger_event
77108

78-
headers = {}
79-
if event.get("headers"):
80-
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
109+
headers = self.transform_headers()
110+
list_headers = [list(x) for x in headers]
111+
# Unique headers. If there are duplicates, it will use the last defined.
112+
uq_headers = {k.decode(): v.decode() for k, v in headers}
81113

82-
source_ip = headers.get("x-forwarded-for", "")
114+
source_ip = uq_headers.get("x-forwarded-for", "")
83115
path = event["path"]
84116
http_method = event["httpMethod"]
85117
query_string = self.encode_query_string()
86118

87-
server_name = headers.get("host", "mangum")
119+
server_name = uq_headers.get("host", "mangum")
88120
if ":" not in server_name:
89-
server_port = headers.get("x-forwarded-port", 80)
121+
server_port = uq_headers.get("x-forwarded-port", 80)
90122
else:
91123
server_name, server_port = server_name.split(":") # pragma: no cover
92124
server = (server_name, int(server_port))
@@ -97,9 +129,9 @@ def request(self) -> Request:
97129

98130
return Request(
99131
method=http_method,
100-
headers=[[k.encode(), v.encode()] for k, v in headers.items()],
132+
headers=list_headers,
101133
path=urllib.parse.unquote(path),
102-
scheme=headers.get("x-forwarded-proto", "https"),
134+
scheme=uq_headers.get("x-forwarded-proto", "https"),
103135
query_string=query_string,
104136
server=server,
105137
client=client,
@@ -119,36 +151,34 @@ def body(self) -> bytes:
119151

120152
return body
121153

122-
def handle_headers(
123-
self,
124-
response_headers: List[List[bytes]],
125-
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
126-
headers, multi_value_headers = self._handle_multi_value_headers(
127-
response_headers
128-
)
129-
if "multiValueHeaders" not in self.trigger_event:
130-
# If there are multiple occurrences of headers, create case-mutated
131-
# variations: https://github.com/logandk/serverless-wsgi/issues/11
132-
for key, values in multi_value_headers.items():
133-
if len(values) > 1:
134-
for value, cased_key in zip(values, all_casings(key)):
135-
headers[cased_key] = value
136-
137-
multi_value_headers = {}
154+
def transform_response(self, response: Response) -> Dict[str, Any]:
138155

139-
return headers, multi_value_headers
156+
multi_value_headers: Dict[str, List[str]] = {}
157+
for key, value in response.headers:
158+
lower_key = key.decode().lower()
159+
if lower_key not in multi_value_headers:
160+
multi_value_headers[lower_key] = []
161+
multi_value_headers[lower_key].append(value.decode())
140162

141-
def transform_response(self, response: Response) -> Dict[str, Any]:
142-
headers, multi_value_headers = self.handle_headers(response.headers)
163+
headers = case_mutated_headers(multi_value_headers)
143164

144165
body, is_base64_encoded = self._handle_base64_response_body(
145166
response.body, headers
146167
)
147168

148-
return {
169+
out = {
149170
"statusCode": response.status,
150-
"headers": headers,
151-
"multiValueHeaders": multi_value_headers,
152171
"body": body,
153172
"isBase64Encoded": is_base64_encoded,
154173
}
174+
175+
# "You must use multiValueHeaders if you have enabled multi-value headers
176+
# and headers otherwise"
177+
# https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html
178+
multi_value_headers_enabled = "multiValueHeaders" in self.trigger_event
179+
if multi_value_headers_enabled:
180+
out["multiValueHeaders"] = multi_value_headers
181+
else:
182+
out["headers"] = headers
183+
184+
return out

0 commit comments

Comments
 (0)