Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions mangum/handlers/aws_alb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
import base64
import urllib.parse
from typing import Dict, Any
from typing import Any, Dict, Generator, List, Tuple

from .abstract_handler import AbstractHandler
from .. import Response, Request


def all_casings(input_string: str) -> Generator:
"""
Permute all casings of a given string.
A pretty algoritm, via @Amber
http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
"""
if not input_string:
yield ""
else:
first = input_string[:1]
if first.lower() == first.upper():
for sub_casing in all_casings(input_string[1:]):
yield first + sub_casing
else:
for sub_casing in all_casings(input_string[1:]):
yield first.lower() + sub_casing
yield first.upper() + sub_casing


class AwsAlb(AbstractHandler):
"""
Handles AWS Elastic Load Balancer, really Application Load Balancer events
Expand Down Expand Up @@ -66,10 +85,27 @@ def body(self) -> bytes:

return body

def transform_response(self, response: Response) -> Dict[str, Any]:
def handle_headers(
self,
response_headers: List[List[bytes]],
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
headers, multi_value_headers = self._handle_multi_value_headers(
response.headers
response_headers
)
if "multiValueHeaders" not in self.trigger_event:
# If there are multiple occurrences of headers, create case-mutated
# variations: https://github.com/logandk/serverless-wsgi/issues/11
for key, values in multi_value_headers.items():
if len(values) > 1:
for value, cased_key in zip(values, all_casings(key)):
headers[cased_key] = value

multi_value_headers = {}

return headers, multi_value_headers

def transform_response(self, response: Response) -> Dict[str, Any]:
headers, multi_value_headers = self.handle_headers(response.headers)

body, is_base64_encoded = self._handle_base64_response_body(
response.body, headers
Expand Down
48 changes: 45 additions & 3 deletions tests/handlers/test_aws_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@


def get_mock_aws_alb_event(
method, path, multi_value_query_parameters, body, body_base64_encoded
method,
path,
multi_value_query_parameters,
body,
body_base64_encoded,
multi_value_headers=True,
):
return {
event = {
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501
Expand Down Expand Up @@ -38,6 +43,10 @@ def get_mock_aws_alb_event(
"body": body,
"isBase64Encoded": body_base64_encoded,
}
if multi_value_headers:
event["multiValueHeaders"] = {}
Comment on lines +46 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again this specifies headers and multiValueHeaders


return event


def test_aws_alb_basic():
Expand Down Expand Up @@ -226,7 +235,7 @@ def test_aws_alb_scope_real(
assert handler.body == b""


def test_aws_alb_set_cookies() -> None:
def test_aws_alb_set_cookies_multiValueHeaders() -> None:
async def app(scope, receive, send):
await send(
{
Expand Down Expand Up @@ -255,6 +264,39 @@ async def app(scope, receive, send):
}


def test_aws_alb_set_cookies_headers() -> None:
async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"content-type", b"text/plain; charset=utf-8"],
[b"set-cookie", b"cookie1=cookie1; Secure"],
[b"set-cookie", b"cookie2=cookie2; Secure"],
],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})

handler = Mangum(app, lifespan="off")
event = get_mock_aws_alb_event(
"GET", "/test", {}, None, False, multi_value_headers=False
)
response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {
"content-type": "text/plain; charset=utf-8",
"set-cookie": "cookie1=cookie1; Secure",
"Set-cookie": "cookie2=cookie2; Secure",
},
"multiValueHeaders": {},
Comment on lines +290 to +295
Copy link
Copy Markdown
Contributor

@jurasofish jurasofish Apr 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only one of headers and multiValueHeaders should be specified
"You must use multiValueHeaders if you have enabled multi-value headers and headers otherwise"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation you linked makes it sound like all headers should be in multiValueHeaders, even those with a single value, if it is enabled.

I'm not sure if 0.11.0 did this or did what the current "main" branch does and only send the headers that have multiple values in multiValueHeaders.

I think using one or the other may be the way to go, it's a bit easier to understand what is going on, and is how the ALB documentation reads online. Let us know what AWS says if anything.

"body": "Hello, world!",
}


@pytest.mark.parametrize(
"method,content_type,raw_res_body,res_body,res_base64_encoded",
[
Expand Down