Skip to content

Commit f171204

Browse files
araki-yzrhfour43
authored andcommitted
ELB Support / Fix: APIGW v2 cookies response (Kludex#155)
* fix apigw v2 request cookie header * add v1 cookie header test * support elb * support elb test * refactored adapter event parsing * fix elb single value headers response * fix apigw v2 cookie response * remove old code
1 parent 073d0f8 commit f171204

File tree

4 files changed

+310
-33
lines changed

4 files changed

+310
-33
lines changed

mangum/adapter.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,34 +90,44 @@ def __call__(self, event: dict, context: dict) -> dict:
9090
stack.enter_context(lifespan_cycle)
9191

9292
request_context = event["requestContext"]
93-
if "http" in request_context:
93+
94+
if event.get("multiValueHeaders"):
95+
headers = {k.lower(): ", ".join(v) if isinstance(v, list) else ""
96+
for k, v in event.get("multiValueHeaders", {}).items()}
97+
elif event.get("headers"):
98+
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
99+
else:
100+
headers = {}
101+
102+
# API Gateway v2
103+
if event.get("version") == "2.0":
94104
source_ip = request_context["http"]["sourceIp"]
95105
path = request_context["http"]["path"]
96106
http_method = request_context["http"]["method"]
97107
query_string = event.get("rawQueryString", "").encode()
108+
109+
if event.get("cookies"):
110+
headers["cookie"] = "; ".join(event.get("cookies", []))
111+
112+
# API Gateway v1 / ELB
98113
else:
99-
source_ip = request_context.get("identity", {}).get("sourceIp")
100-
multi_value_query_string_params = event[
101-
"multiValueQueryStringParameters"
102-
]
103-
query_string = (
104-
urllib.parse.urlencode(
105-
multi_value_query_string_params, doseq=True
106-
).encode()
107-
if multi_value_query_string_params
108-
else b""
109-
)
114+
if "elb" in request_context:
115+
# NOTE: trust only the most right side value
116+
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
117+
else:
118+
source_ip = request_context.get("identity", {}).get("sourceIp")
119+
110120
path = event["path"]
111121
http_method = event["httpMethod"]
112122

113-
headers = (
114-
{k.lower(): v for k, v in event.get("headers", {}).items()}
115-
if event.get("headers")
116-
else {}
117-
)
118-
119-
if "cookies" in event:
120-
headers["cookie"] = "; ".join(event.get("cookies", []))
123+
if event.get("multiValueQueryStringParameters"):
124+
query_string = urllib.parse.urlencode(
125+
event.get("multiValueQueryStringParameters", {}), doseq=True).encode()
126+
elif event.get("queryStringParameters"):
127+
query_string = urllib.parse.urlencode(
128+
event.get("queryStringParameters", {})).encode()
129+
else:
130+
query_string = b""
121131

122132
server_name = headers.get("host", "mangum")
123133
if ":" not in server_name:

mangum/protocols/http.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,25 @@
1111
from mangum.exceptions import UnexpectedMessage
1212

1313

14+
def all_casings(input_string):
15+
"""
16+
Permute all casings of a given string.
17+
A pretty algoritm, via @Amber
18+
http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
19+
"""
20+
if not input_string:
21+
yield ""
22+
else:
23+
first = input_string[:1]
24+
if first.lower() == first.upper():
25+
for sub_casing in all_casings(input_string[1:]):
26+
yield first + sub_casing
27+
else:
28+
for sub_casing in all_casings(input_string[1:]):
29+
yield first.lower() + sub_casing
30+
yield first.upper() + sub_casing
31+
32+
1433
class HTTPCycleState(enum.Enum):
1534
"""
1635
The state of the ASGI `http` connection.
@@ -116,21 +135,47 @@ async def send(self, message: Message) -> None:
116135
self.response["statusCode"] = message["status"]
117136
headers: typing.Dict[str, str] = {}
118137
multi_value_headers: typing.Dict[str, typing.List[str]] = {}
119-
for key, value in message.get("headers", []):
120-
lower_key = key.decode().lower()
121-
if lower_key in multi_value_headers:
122-
multi_value_headers[lower_key].append(value.decode())
123-
elif lower_key in headers:
124-
multi_value_headers[lower_key] = [
125-
headers.pop(lower_key),
126-
value.decode(),
127-
]
128-
else:
129-
headers[lower_key] = value.decode()
138+
cookies: typing.List[str] = []
139+
event = self.scope["aws.event"]
140+
# ELB
141+
if "elb" in event["requestContext"]:
142+
for key, value in message.get("headers", []):
143+
lower_key = key.decode().lower()
144+
if lower_key in multi_value_headers:
145+
multi_value_headers[lower_key].append(value.decode())
146+
else:
147+
multi_value_headers[lower_key] = [value.decode()]
148+
if "multiValueHeaders" not in event:
149+
# If there are multiple occurrences of headers, create case-mutated variations
150+
# see: https://github.com/logandk/serverless-wsgi/issues/11
151+
for key, values in multi_value_headers.items():
152+
if len(values) > 1:
153+
for value, cased_key in zip(values, all_casings(key)):
154+
headers[cased_key] = value
155+
elif len(values) == 1:
156+
headers[key] = values[0]
157+
multi_value_headers = {}
158+
# API Gateway
159+
else:
160+
for key, value in message.get("headers", []):
161+
lower_key = key.decode().lower()
162+
if event.get("version") == "2.0" and lower_key == "set-cookie":
163+
cookies.append(value.decode())
164+
elif lower_key in multi_value_headers:
165+
multi_value_headers[lower_key].append(value.decode())
166+
elif lower_key in headers:
167+
multi_value_headers[lower_key] = [
168+
headers.pop(lower_key),
169+
value.decode(),
170+
]
171+
else:
172+
headers[lower_key] = value.decode()
130173

131174
self.response["headers"] = headers
132175
if multi_value_headers:
133176
self.response["multiValueHeaders"] = multi_value_headers
177+
if len(cookies):
178+
self.response["cookies"] = cookies
134179
self.state = HTTPCycleState.RESPONSE
135180

136181
elif (

tests/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,65 @@ def mock_http_api_event(request):
116116
}
117117

118118
return event
119+
120+
121+
@pytest.fixture
122+
def mock_http_elb_singlevalue_event(request):
123+
method = request.param[0]
124+
body = request.param[1]
125+
multi_value_query_parameters = request.param[2]
126+
event = {
127+
"requestContext": {
128+
"elb": {
129+
"targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0"
130+
}
131+
},
132+
"httpMethod": method,
133+
"path": "/my/path",
134+
"queryStringParameters": {
135+
k: v[-1] for k, v in multi_value_query_parameters.items()
136+
}
137+
if multi_value_query_parameters
138+
else None,
139+
"headers": {
140+
"accept-encoding": "gzip, deflate",
141+
"cookie": "cookie1; cookie2",
142+
"host": "test.execute-api.us-west-2.amazonaws.com",
143+
"x-forwarded-for": "192.168.100.3, 192.168.100.2, 192.168.100.1",
144+
"x-forwarded-port": "443",
145+
"x-forwarded-proto": "https",
146+
},
147+
"body": body,
148+
"isBase64Encoded": False
149+
}
150+
151+
return event
152+
153+
154+
@pytest.fixture
155+
def mock_http_elb_multivalue_event(request):
156+
method = request.param[0]
157+
body = request.param[1]
158+
multi_value_query_parameters = request.param[2]
159+
event = {
160+
"requestContext": {
161+
"elb": {
162+
"targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0"
163+
}
164+
},
165+
"httpMethod": method,
166+
"path": "/my/path",
167+
"multiValueQueryStringParameters": multi_value_query_parameters or None,
168+
"multiValueHeaders": {
169+
"accept-encoding": ["gzip, deflate"],
170+
"cookie": ["cookie1; cookie2"],
171+
"host": ["test.execute-api.us-west-2.amazonaws.com"],
172+
"x-forwarded-for": ["192.168.100.3, 192.168.100.2, 192.168.100.1"],
173+
"x-forwarded-port": ["443"],
174+
"x-forwarded-proto": ["https"],
175+
},
176+
"body": body,
177+
"isBase64Encoded": False
178+
}
179+
180+
return event

0 commit comments

Comments
 (0)