Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 61 additions & 31 deletions mangum/handlers/aws_alb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import urllib.parse
from typing import Any, Dict, Generator, List, Tuple
from itertools import islice

from .abstract_handler import AbstractHandler
from .. import Response, Request
Expand All @@ -25,12 +26,25 @@ def all_casings(input_string: str) -> Generator:
yield first.upper() + sub_casing


def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
"""Create str/str key/value headers, with duplicate keys case mutated."""
headers = {}
for key, values in multi_value_headers.items():
if len(values) > 0:
casings = list(islice(all_casings(key), len(values)))
for value, cased_key in zip(values, casings):
headers[cased_key] = value
return headers


class AwsAlb(AbstractHandler):
"""
Handles AWS Elastic Load Balancer, really Application Load Balancer events
transforming them into ASGI Scope and handling responses

See: https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
See:
1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
"""

TYPE = "AWS_ALB"
Expand Down Expand Up @@ -71,22 +85,40 @@ def encode_query_string(self) -> bytes:

return urllib.parse.urlencode(query).encode()

def transform_headers(self) -> List[Tuple[bytes, bytes]]:
"""Convert headers to a list of two-tuples per ASGI spec.

Only one of `multiValueHeaders` or `headers` should be defined in the
trigger event. However, we act as though they both might exist and pull
Comment thread
jordaneremieff marked this conversation as resolved.
headers out of both.
"""
headers = []
if "multiValueHeaders" in self.trigger_event:
for k, v in self.trigger_event["multiValueHeaders"].items():
for inner_v in v:
headers.append((k.lower().encode(), inner_v.encode()))
else:
for k, v in self.trigger_event["headers"].items():
headers.append((k.lower().encode(), v.encode()))
return headers

@property
def request(self) -> Request:
event = self.trigger_event

headers = {}
if event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
headers = self.transform_headers()
list_headers = [list(x) for x in headers]
# Unique headers. If there are duplicates, it will use the last defined.
uq_headers = {k.decode(): v.decode() for k, v in headers}

source_ip = headers.get("x-forwarded-for", "")
source_ip = uq_headers.get("x-forwarded-for", "")
path = event["path"]
http_method = event["httpMethod"]
query_string = self.encode_query_string()

server_name = headers.get("host", "mangum")
server_name = uq_headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
server_port = uq_headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
Expand All @@ -97,9 +129,9 @@ def request(self) -> Request:

return Request(
method=http_method,
headers=[[k.encode(), v.encode()] for k, v in headers.items()],
headers=list_headers,
path=urllib.parse.unquote(path),
scheme=headers.get("x-forwarded-proto", "https"),
scheme=uq_headers.get("x-forwarded-proto", "https"),
query_string=query_string,
server=server,
client=client,
Expand All @@ -119,36 +151,34 @@ def body(self) -> bytes:

return body

def handle_headers(
self,
response_headers: List[List[bytes]],
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
headers, multi_value_headers = self._handle_multi_value_headers(
response_headers
)
if "multiValueHeaders" not in self.trigger_event:
# If there are multiple occurrences of headers, create case-mutated
# variations: https://github.com/logandk/serverless-wsgi/issues/11
for key, values in multi_value_headers.items():
if len(values) > 1:
for value, cased_key in zip(values, all_casings(key)):
headers[cased_key] = value

multi_value_headers = {}
def transform_response(self, response: Response) -> Dict[str, Any]:

return headers, multi_value_headers
multi_value_headers: Dict[str, List[str]] = {}
for key, value in response.headers:
lower_key = key.decode().lower()
if lower_key not in multi_value_headers:
multi_value_headers[lower_key] = []
multi_value_headers[lower_key].append(value.decode())

def transform_response(self, response: Response) -> Dict[str, Any]:
headers, multi_value_headers = self.handle_headers(response.headers)
headers = case_mutated_headers(multi_value_headers)

body, is_base64_encoded = self._handle_base64_response_body(
response.body, headers
)

return {
out = {
"statusCode": response.status,
"headers": headers,
"multiValueHeaders": multi_value_headers,
"body": body,
"isBase64Encoded": is_base64_encoded,
}

# "You must use multiValueHeaders if you have enabled multi-value headers
# and headers otherwise"
# https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html
multi_value_headers_enabled = "multiValueHeaders" in self.trigger_event
if multi_value_headers_enabled:
out["multiValueHeaders"] = multi_value_headers
else:
out["headers"] = headers

return out
Loading