Skip to content

Commit 881568d

Browse files
authored
#185 add case mutation code back for alb when multi value headers is not set (#186)
1 parent 619790e commit 881568d

File tree

2 files changed

+84
-6
lines changed

2 files changed

+84
-6
lines changed

mangum/handlers/aws_alb.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,30 @@
11
import base64
22
import urllib.parse
3-
from typing import Dict, Any
3+
from typing import Any, Dict, Generator, List, Tuple
44

55
from .abstract_handler import AbstractHandler
66
from .. import Response, Request
77

88

9+
def all_casings(input_string: str) -> Generator:
10+
"""
11+
Permute all casings of a given string.
12+
A pretty algoritm, via @Amber
13+
http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
14+
"""
15+
if not input_string:
16+
yield ""
17+
else:
18+
first = input_string[:1]
19+
if first.lower() == first.upper():
20+
for sub_casing in all_casings(input_string[1:]):
21+
yield first + sub_casing
22+
else:
23+
for sub_casing in all_casings(input_string[1:]):
24+
yield first.lower() + sub_casing
25+
yield first.upper() + sub_casing
26+
27+
928
class AwsAlb(AbstractHandler):
1029
"""
1130
Handles AWS Elastic Load Balancer, really Application Load Balancer events
@@ -66,10 +85,27 @@ def body(self) -> bytes:
6685

6786
return body
6887

69-
def transform_response(self, response: Response) -> Dict[str, Any]:
88+
def handle_headers(
89+
self,
90+
response_headers: List[List[bytes]],
91+
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
7092
headers, multi_value_headers = self._handle_multi_value_headers(
71-
response.headers
93+
response_headers
7294
)
95+
if "multiValueHeaders" not in self.trigger_event:
96+
# If there are multiple occurrences of headers, create case-mutated
97+
# variations: https://github.com/logandk/serverless-wsgi/issues/11
98+
for key, values in multi_value_headers.items():
99+
if len(values) > 1:
100+
for value, cased_key in zip(values, all_casings(key)):
101+
headers[cased_key] = value
102+
103+
multi_value_headers = {}
104+
105+
return headers, multi_value_headers
106+
107+
def transform_response(self, response: Response) -> Dict[str, Any]:
108+
headers, multi_value_headers = self.handle_headers(response.headers)
73109

74110
body, is_base64_encoded = self._handle_base64_response_body(
75111
response.body, headers

tests/handlers/test_aws_alb.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55

66

77
def get_mock_aws_alb_event(
8-
method, path, multi_value_query_parameters, body, body_base64_encoded
8+
method,
9+
path,
10+
multi_value_query_parameters,
11+
body,
12+
body_base64_encoded,
13+
multi_value_headers=True,
914
):
10-
return {
15+
event = {
1116
"requestContext": {
1217
"elb": {
1318
"targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501
@@ -38,6 +43,10 @@ def get_mock_aws_alb_event(
3843
"body": body,
3944
"isBase64Encoded": body_base64_encoded,
4045
}
46+
if multi_value_headers:
47+
event["multiValueHeaders"] = {}
48+
49+
return event
4150

4251

4352
def test_aws_alb_basic():
@@ -226,7 +235,7 @@ def test_aws_alb_scope_real(
226235
assert handler.body == b""
227236

228237

229-
def test_aws_alb_set_cookies() -> None:
238+
def test_aws_alb_set_cookies_multiValueHeaders() -> None:
230239
async def app(scope, receive, send):
231240
await send(
232241
{
@@ -255,6 +264,39 @@ async def app(scope, receive, send):
255264
}
256265

257266

267+
def test_aws_alb_set_cookies_headers() -> None:
268+
async def app(scope, receive, send):
269+
await send(
270+
{
271+
"type": "http.response.start",
272+
"status": 200,
273+
"headers": [
274+
[b"content-type", b"text/plain; charset=utf-8"],
275+
[b"set-cookie", b"cookie1=cookie1; Secure"],
276+
[b"set-cookie", b"cookie2=cookie2; Secure"],
277+
],
278+
}
279+
)
280+
await send({"type": "http.response.body", "body": b"Hello, world!"})
281+
282+
handler = Mangum(app, lifespan="off")
283+
event = get_mock_aws_alb_event(
284+
"GET", "/test", {}, None, False, multi_value_headers=False
285+
)
286+
response = handler(event, {})
287+
assert response == {
288+
"statusCode": 200,
289+
"isBase64Encoded": False,
290+
"headers": {
291+
"content-type": "text/plain; charset=utf-8",
292+
"set-cookie": "cookie1=cookie1; Secure",
293+
"Set-cookie": "cookie2=cookie2; Secure",
294+
},
295+
"multiValueHeaders": {},
296+
"body": "Hello, world!",
297+
}
298+
299+
258300
@pytest.mark.parametrize(
259301
"method,content_type,raw_res_body,res_body,res_base64_encoded",
260302
[

0 commit comments

Comments
 (0)