diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py index 318f9f38..4a918f00 100644 --- a/mangum/handlers/aws_alb.py +++ b/mangum/handlers/aws_alb.py @@ -1,6 +1,7 @@ import base64 import urllib.parse from typing import Any, Dict, Generator, List, Tuple +from itertools import islice from .abstract_handler import AbstractHandler from .. import Response, Request @@ -25,12 +26,25 @@ def all_casings(input_string: str) -> Generator: yield first.upper() + sub_casing +def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]: + """Create str/str key/value headers, with duplicate keys case mutated.""" + headers = {} + for key, values in multi_value_headers.items(): + if len(values) > 0: + casings = list(islice(all_casings(key), len(values))) + for value, cased_key in zip(values, casings): + headers[cased_key] = value + return headers + + class AwsAlb(AbstractHandler): """ Handles AWS Elastic Load Balancer, really Application Load Balancer events transforming them into ASGI Scope and handling responses - See: https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html + See: + 1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html + 2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501 """ TYPE = "AWS_ALB" @@ -71,22 +85,40 @@ def encode_query_string(self) -> bytes: return urllib.parse.urlencode(query).encode() + def transform_headers(self) -> List[Tuple[bytes, bytes]]: + """Convert headers to a list of two-tuples per ASGI spec. + + Only one of `multiValueHeaders` or `headers` should be defined in the + trigger event. However, we act as though they both might exist and pull + headers out of both. + """ + headers = [] + if "multiValueHeaders" in self.trigger_event: + for k, v in self.trigger_event["multiValueHeaders"].items(): + for inner_v in v: + headers.append((k.lower().encode(), inner_v.encode())) + else: + for k, v in self.trigger_event["headers"].items(): + headers.append((k.lower().encode(), v.encode())) + return headers + @property def request(self) -> Request: event = self.trigger_event - headers = {} - if event.get("headers"): - headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + headers = self.transform_headers() + list_headers = [list(x) for x in headers] + # Unique headers. If there are duplicates, it will use the last defined. + uq_headers = {k.decode(): v.decode() for k, v in headers} - source_ip = headers.get("x-forwarded-for", "") + source_ip = uq_headers.get("x-forwarded-for", "") path = event["path"] http_method = event["httpMethod"] query_string = self.encode_query_string() - server_name = headers.get("host", "mangum") + server_name = uq_headers.get("host", "mangum") if ":" not in server_name: - server_port = headers.get("x-forwarded-port", 80) + server_port = uq_headers.get("x-forwarded-port", 80) else: server_name, server_port = server_name.split(":") # pragma: no cover server = (server_name, int(server_port)) @@ -97,9 +129,9 @@ def request(self) -> Request: return Request( method=http_method, - headers=[[k.encode(), v.encode()] for k, v in headers.items()], + headers=list_headers, path=urllib.parse.unquote(path), - scheme=headers.get("x-forwarded-proto", "https"), + scheme=uq_headers.get("x-forwarded-proto", "https"), query_string=query_string, server=server, client=client, @@ -119,36 +151,34 @@ def body(self) -> bytes: return body - def handle_headers( - self, - response_headers: List[List[bytes]], - ) -> Tuple[Dict[str, str], Dict[str, List[str]]]: - headers, multi_value_headers = self._handle_multi_value_headers( - response_headers - ) - if "multiValueHeaders" not in self.trigger_event: - # If there are multiple occurrences of headers, create case-mutated - # variations: https://github.com/logandk/serverless-wsgi/issues/11 - for key, values in multi_value_headers.items(): - if len(values) > 1: - for value, cased_key in zip(values, all_casings(key)): - headers[cased_key] = value - - multi_value_headers = {} + def transform_response(self, response: Response) -> Dict[str, Any]: - return headers, multi_value_headers + multi_value_headers: Dict[str, List[str]] = {} + for key, value in response.headers: + lower_key = key.decode().lower() + if lower_key not in multi_value_headers: + multi_value_headers[lower_key] = [] + multi_value_headers[lower_key].append(value.decode()) - def transform_response(self, response: Response) -> Dict[str, Any]: - headers, multi_value_headers = self.handle_headers(response.headers) + headers = case_mutated_headers(multi_value_headers) body, is_base64_encoded = self._handle_base64_response_body( response.body, headers ) - return { + out = { "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, "body": body, "isBase64Encoded": is_base64_encoded, } + + # "You must use multiValueHeaders if you have enabled multi-value headers + # and headers otherwise" + # https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html + multi_value_headers_enabled = "multiValueHeaders" in self.trigger_event + if multi_value_headers_enabled: + out["multiValueHeaders"] = multi_value_headers + else: + out["headers"] = headers + + return out diff --git a/tests/handlers/test_aws_alb.py b/tests/handlers/test_aws_alb.py index 8fb8cbac..055c1286 100644 --- a/tests/handlers/test_aws_alb.py +++ b/tests/handlers/test_aws_alb.py @@ -1,193 +1,201 @@ +""" +References: +1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html +2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501 +""" import pytest from mangum import Mangum from mangum.handlers import AwsAlb +from typing import Dict, List, Optional def get_mock_aws_alb_event( method, path, - multi_value_query_parameters, + query_parameters: Optional[Dict[str, List[str]]], + headers: Optional[Dict[str, List[str]]], body, body_base64_encoded, - multi_value_headers=True, + multi_value_headers: bool, ): - event = { + """Return a mock AWS ELB event. + + The `query_parameters` parameter must be given in the + `multiValueQueryStringParameters` format - and if `multi_value_headers` + is disabled, then they are simply transformed in to the + `queryStringParameters` format. + Similarly for `headers`. + If `headers` is None, then some defaults will be used. + if `query_parameters` is None, then no query parameters will be used. + """ + resp = { "requestContext": { "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501 + "targetGroupArn": ( + "arn:aws:elasticloadbalancing:us-east-2:123456789012:" + "targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" + ) } }, "httpMethod": method, "path": path, - "multiValueQueryStringParameters": multi_value_query_parameters - if multi_value_query_parameters - else {}, - "headers": { - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9," - "image/webp,image/apng,*/*;q=0.8", - "accept-encoding": "gzip", - "accept-language": "en-US,en;q=0.9", - "connection": "keep-alive", - "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", - "upgrade-insecure-requests": "1", - "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/71.0.3578.98 Safari/537.36", - "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", - "x-forwarded-for": "72.12.164.125", - "x-forwarded-port": "80", - "x-forwarded-proto": "http", - "x-imforwards": "20", - }, "body": body, "isBase64Encoded": body_base64_encoded, } - if multi_value_headers: - event["multiValueHeaders"] = {} - return event + if headers is None: + headers = { + "accept": [ + "text/html,application/xhtml+xml,application/xml;" + "q=0.9,image/webp,image/apng,*/*;q=0.8" + ], + "accept-encoding": ["gzip"], + "accept-language": ["en-US,en;q=0.9"], + "connection": ["keep-alive"], + "host": ["lambda-alb-123578498.us-east-2.elb.amazonaws.com"], + "upgrade-insecure-requests": ["1"], + "user-agent": [ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36" + ], + "x-amzn-trace-id": ["Root=1-5c536348-3d683b8b04734faae651f476"], + "x-forwarded-for": ["72.12.164.125"], + "x-forwarded-port": ["80"], + "x-forwarded-proto": ["http"], + "x-imforwards": ["20"], + } + query_parameters = {} if query_parameters is None else query_parameters -def test_aws_alb_basic(): - """ - Test the event from the AWS docs - """ - example_event = { - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501 - } - }, - "httpMethod": "GET", - "path": "/lambda", - "queryStringParameters": { - "q1": "1234ABCD", - "q2": "b c", # not encoded - "q3": "b%20c", # encoded - "q4": "/some/path/", # not encoded - "q5": "%2Fsome%2Fpath%2F", # encoded - }, - "headers": { - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9," - "image/webp,image/apng,*/*;q=0.8", - "accept-encoding": "gzip", - "accept-language": "en-US,en;q=0.9", - "connection": "keep-alive", - "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", - "upgrade-insecure-requests": "1", - "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", # noqa: E501 - "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", - "x-forwarded-for": "72.12.164.125", - "x-forwarded-port": "80", - "x-forwarded-proto": "http", - "x-imforwards": "20", - }, - "body": "", - "isBase64Encoded": False, - } + # Only set one of `queryStringParameters`/`multiValueQueryStringParameters` + # and one of `headers`/multiValueHeaders (per AWS docs for ALB/lambda) + if multi_value_headers: + resp["multiValueQueryStringParameters"] = query_parameters + resp["multiValueHeaders"] = headers + else: + # Take the last query parameter/cookie (per AWS docs for ALB/lambda) + resp["queryStringParameters"] = { + k: (v[-1] if len(v) > 0 else []) for k, v in query_parameters.items() + } + resp["headers"] = {k: (v[-1] if len(v) > 0 else []) for k, v in headers.items()} - example_context = {} - handler = AwsAlb(example_event, example_context) - - assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, - "aws.context": {}, - "aws.event": example_event, - "aws.eventType": "AWS_ALB", - "client": ("72.12.164.125", 0), - "headers": [ - [ - b"accept", - b"text/html,application/xhtml+xml,application/xml;q=0.9,image/" - b"webp,image/apng,*/*;q=0.8", - ], - [b"accept-encoding", b"gzip"], - [b"accept-language", b"en-US,en;q=0.9"], - [b"connection", b"keep-alive"], - [b"host", b"lambda-alb-123578498.us-east-2.elb.amazonaws.com"], - [b"upgrade-insecure-requests", b"1"], - [ - b"user-agent", - b"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" - b" (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", - ], - [b"x-amzn-trace-id", b"Root=1-5c536348-3d683b8b04734faae651f476"], - [b"x-forwarded-for", b"72.12.164.125"], - [b"x-forwarded-port", b"80"], - [b"x-forwarded-proto", b"http"], - [b"x-imforwards", b"20"], - ], - "http_version": "1.1", - "method": "GET", - "path": "/lambda", - "query_string": b"q1=1234ABCD&q2=b+c&q3=b+c&q4=%2Fsome%2Fpath%2F&q5=%2Fsome%2Fpath%2F", # noqa: E501 - "raw_path": None, - "root_path": "", - "scheme": "http", - "server": ("lambda-alb-123578498.us-east-2.elb.amazonaws.com", 80), - "type": "http", - } + return resp @pytest.mark.parametrize( - "method,path,multi_value_query_parameters,req_body,body_base64_encoded," - "query_string,scope_body", + "method,path,query_parameters,headers,req_body,body_base64_encoded," + "query_string,scope_body,multi_value_headers", [ - ("GET", "/hello/world", None, None, False, b"", None), + ("GET", "/hello/world", None, None, None, False, b"", None, False), + ( + "GET", + "/lambda", + { + "q1": ["1234ABCD"], + "q2": ["b+c"], # not encoded + "q3": ["b%20c"], # encoded + "q4": ["/some/path/"], # not encoded + "q5": ["%2Fsome%2Fpath%2F"], # encoded + }, + None, + "", + False, + b"q1=1234ABCD&q2=b+c&q3=b+c&q4=%2Fsome%2Fpath%2F&q5=%2Fsome%2Fpath%2F", + "", + False, + ), ( "POST", "/", {"name": ["me"]}, + None, "field1=value1&field2=value2", False, b"name=me", b"field1=value1&field2=value2", + False, + ), + # Duplicate query params with multi-value headers disabled: + ( + "POST", + "/", + {"name": ["me", "you"]}, + None, + None, + False, + b"name=you", + None, + False, ), + # Duplicate query params with multi-value headers enable: ( "GET", "/my/resource", {"name": ["me", "you"]}, None, + None, False, b"name=me&name=you", None, + True, ), ( "GET", "", {"name": ["me", "you"], "pet": ["dog"]}, None, + None, False, b"name=me&name=you&pet=dog", None, + True, ), # A 1x1 red px gif ( "POST", "/img", None, + None, b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", True, b"", b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + False, + ), + ( + "POST", + "/form-submit", + None, + None, + b"say=Hi&to=Mom", + False, + b"", + b"say=Hi&to=Mom", + False, ), - ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), ], ) def test_aws_alb_scope_real( method, path, - multi_value_query_parameters, + query_parameters, + headers, req_body, body_base64_encoded, query_string, scope_body, + multi_value_headers, ): event = get_mock_aws_alb_event( - method, path, multi_value_query_parameters, req_body, body_base64_encoded + method, + path, + query_parameters, + headers, + req_body, + body_base64_encoded, + multi_value_headers, ) example_context = {} handler = AwsAlb(event, example_context) @@ -196,6 +204,7 @@ def test_aws_alb_scope_real( if scope_path == "": scope_path = "/" + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -241,36 +250,8 @@ def test_aws_alb_scope_real( assert handler.body == b"" -def test_aws_alb_set_cookies_multiValueHeaders() -> None: - async def app(scope, receive, send): - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [ - [b"content-type", b"text/plain; charset=utf-8"], - [b"set-cookie", b"cookie1=cookie1; Secure"], - [b"set-cookie", b"cookie2=cookie2; Secure"], - ], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - handler = Mangum(app, lifespan="off") - event = get_mock_aws_alb_event("GET", "/test", {}, None, False) - response = handler(event, {}) - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {"content-type": "text/plain; charset=utf-8"}, - "multiValueHeaders": { - "set-cookie": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], - }, - "body": "Hello, world!", - } - - -def test_aws_alb_set_cookies_headers() -> None: +@pytest.mark.parametrize("multi_value_headers_enabled", (True, False)) +def test_aws_alb_set_cookies(multi_value_headers_enabled) -> None: async def app(scope, receive, send): await send( { @@ -287,20 +268,28 @@ async def app(scope, receive, send): handler = Mangum(app, lifespan="off") event = get_mock_aws_alb_event( - "GET", "/test", {}, None, False, multi_value_headers=False + "GET", "/test", {}, None, None, False, multi_value_headers_enabled ) response = handler(event, {}) - assert response == { + + expected_response = { "statusCode": 200, "isBase64Encoded": False, - "headers": { + "body": "Hello, world!", + } + if multi_value_headers_enabled: + expected_response["multiValueHeaders"] = { + "set-cookie": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], + "content-type": ["text/plain; charset=utf-8"], + } + else: + expected_response["headers"] = { "content-type": "text/plain; charset=utf-8", + # Should see case mutated keys to avoid duplicate keys: "set-cookie": "cookie1=cookie1; Secure", "Set-cookie": "cookie2=cookie2; Secure", - }, - "multiValueHeaders": {}, - "body": "Hello, world!", - } + } + assert response == expected_response @pytest.mark.parametrize( @@ -332,7 +321,7 @@ async def app(scope, receive, send): ) await send({"type": "http.response.body", "body": raw_res_body}) - event = get_mock_aws_alb_event(method, "/test", {}, None, False) + event = get_mock_aws_alb_event(method, "/test", {}, None, None, False, False) handler = Mangum(app, lifespan="off") @@ -341,6 +330,5 @@ async def app(scope, receive, send): "statusCode": 200, "isBase64Encoded": res_base64_encoded, "headers": {"content-type": content_type.decode()}, - "multiValueHeaders": {}, "body": res_body, }