diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py index 2d8982da..b023d69c 100644 --- a/mangum/handlers/aws_http_gateway.py +++ b/mangum/handlers/aws_http_gateway.py @@ -1,6 +1,6 @@ import base64 import urllib.parse -from typing import Dict, Any +from typing import Dict, Any, List, Tuple from . import AwsApiGateway from .. import Response, Request @@ -122,37 +122,65 @@ def transform_response(self, response: Response) -> Dict[str, Any]: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response """ + if self.event_version == "1.0": + return self.transform_response_v1(response) + elif self.event_version == "2.0": + return self.transform_response_v2(response) + raise RuntimeError( # pragma: no cover + "Misconfigured event unable to return value, unsupported version." + ) + + def transform_response_v1(self, response: Response) -> Dict[str, Any]: headers, multi_value_headers = self._handle_multi_value_headers( response.headers ) - if self.event_version == "1.0": - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - elif self.event_version == "2.0": - # The API Gateway will infer stuff for us, but we'll just do that inference - # here and keep the output consistent - if "content-type" not in headers and response.body is not None: - headers["content-type"] = "application/json" + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } + + def _combine_headers_v2( + self, input_headers: List[List[bytes]] + ) -> Tuple[Dict[str, str], List[str]]: + output_headers: Dict[str, str] = {} + cookies: List[str] = [] + for key, value in input_headers: + normalized_key: str = key.decode().lower() + normalized_value: str = value.decode() + if normalized_key == "set-cookie": + cookies.append(normalized_value) + else: + if normalized_key in output_headers: + normalized_value = ( + f"{output_headers[normalized_key]},{normalized_value}" + ) + output_headers[normalized_key] = normalized_value + return output_headers, cookies - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - raise RuntimeError( # pragma: no cover - "Misconfigured event unable to return value, unsupported version." + def transform_response_v2(self, response_in: Response) -> Dict[str, Any]: + # The API Gateway will infer stuff for us, but we'll just do that inference + # here and keep the output consistent + + headers, cookies = self._combine_headers_v2(response_in.headers) + + if "content-type" not in headers and response_in.body is not None: + headers["content-type"] = "application/json" + + body, is_base64_encoded = self._handle_base64_response_body( + response_in.body, headers ) + response_out = { + "statusCode": response_in.status, + "body": body, + "headers": headers or None, + "cookies": cookies or None, + "isBase64Encoded": is_base64_encoded, + } + return {key: value for key, value in response_out.items() if value is not None} diff --git a/tests/conftest.py b/tests/conftest.py index c4ad042e..ab04eafc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,7 @@ def mock_aws_api_gateway_event(request): @pytest.fixture -def mock_http_api_event(request): +def mock_http_api_event_v2(request): method = request.param[0] body = request.param[1] multi_value_query_parameters = request.param[2] @@ -120,6 +120,67 @@ def mock_http_api_event(request): return event +@pytest.fixture +def mock_http_api_event_v1(request): + method = request.param[0] + body = request.param[1] + multi_value_query_parameters = request.param[2] + query_string = request.param[3] + event = { + "version": "1.0", + "routeKey": "$default", + "rawPath": "/my/path", + "path": "/my/path", + "httpMethod": method, + "rawQueryString": query_string, + "cookies": ["cookie1", "cookie2"], + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "queryStringParameters": { + k: v[-1] for k, v in multi_value_query_parameters.items() + } + if multi_value_query_parameters + else None, + "multiValueQueryStringParameters": { + k: v for k, v in multi_value_query_parameters.items() + } + if multi_value_query_parameters + else None, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "protocol": "HTTP/1.1", + "sourceIp": "192.168.100.1", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390, + }, + "body": body, + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": False, + "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, + } + + return event + + @pytest.fixture def mock_lambda_at_edge_event(request): method = request.param[0] diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_aws_http_gateway.py index 4e87f9df..91f8ff84 100644 --- a/tests/handlers/test_aws_http_gateway.py +++ b/tests/handlers/test_aws_http_gateway.py @@ -605,6 +605,5 @@ async def app(scope, receive, send): "statusCode": 200, "isBase64Encoded": res_base64_encoded, "headers": {"content-type": content_type.decode()}, - "multiValueHeaders": {}, "body": res_body, } diff --git a/tests/test_http.py b/tests/test_http.py index c82efd8c..4328aba3 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -253,7 +253,7 @@ async def app(scope, receive, send): @pytest.mark.parametrize( - "mock_http_api_event", + "mock_http_api_event_v2", [ (["GET", None, None, ""]), (["GET", None, {"name": ["me"]}, "name=me"]), @@ -267,9 +267,9 @@ async def app(scope, receive, send): ] ), ], - indirect=["mock_http_api_event"], + indirect=["mock_http_api_event_v2"], ) -def test_set_cookies(mock_http_api_event) -> None: +def test_set_cookies_v2(mock_http_api_event_v2) -> None: async def app(scope, receive, send): assert scope == { "asgi": {"version": "3.0"}, @@ -279,7 +279,7 @@ async def app(scope, receive, send): "version": "2.0", "routeKey": "$default", "rawPath": "/my/path", - "rawQueryString": mock_http_api_event["rawQueryString"], + "rawQueryString": mock_http_api_event_v2["rawQueryString"], "cookies": ["cookie1", "cookie2"], "headers": { "accept-encoding": "gzip,deflate", @@ -287,7 +287,9 @@ async def app(scope, receive, send): "x-forwarded-proto": "https", "host": "test.execute-api.us-west-2.amazonaws.com", }, - "queryStringParameters": mock_http_api_event["queryStringParameters"], + "queryStringParameters": mock_http_api_event_v2[ + "queryStringParameters" + ], "requestContext": { "accountId": "123456789012", "apiId": "api-id", @@ -331,7 +333,127 @@ async def app(scope, receive, send): "http_version": "1.1", "method": "GET", "path": "/my/path", - "query_string": mock_http_api_event["rawQueryString"].encode(), + "query_string": mock_http_api_event_v2["rawQueryString"].encode(), + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"set-cookie", b"cookie1=cookie1; Secure"], + [b"set-cookie", b"cookie2=cookie2; Secure"], + [b"multivalue", b"foo"], + [b"multivalue", b"bar"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + handler = Mangum(app, lifespan="off") + response = handler(mock_http_api_event_v2, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": { + "content-type": "text/plain; charset=utf-8", + "multivalue": "foo,bar", + }, + "cookies": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], + "body": "Hello, world!", + } + + +@pytest.mark.parametrize( + "mock_http_api_event_v1", + [ + (["GET", None, None, ""]), + (["GET", None, {"name": ["me"]}, "name=me"]), + (["GET", None, {"name": ["me", "you"]}, "name=me&name=you"]), + ( + [ + "GET", + None, + {"name": ["me", "you"], "pet": ["dog"]}, + "name=me&name=you&pet=dog", + ] + ), + ], + indirect=["mock_http_api_event_v1"], +) +def test_set_cookies_v1(mock_http_api_event_v1) -> None: + async def app(scope, receive, send): + assert scope == { + "asgi": {"version": "3.0"}, + "aws.eventType": "AWS_HTTP_GATEWAY", + "aws.context": {}, + "aws.event": { + "version": "1.0", + "routeKey": "$default", + "rawPath": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "rawQueryString": mock_http_api_event_v1["rawQueryString"], + "cookies": ["cookie1", "cookie2"], + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "queryStringParameters": mock_http_api_event_v1[ + "queryStringParameters" + ], + "multiValueQueryStringParameters": mock_http_api_event_v1[ + "multiValueQueryStringParameters" + ], + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "protocol": "HTTP/1.1", + "sourceIp": "192.168.100.1", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1_583_348_638_390, + }, + "body": None, + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": False, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2", + }, + }, + "client": (None, 0), + "headers": [ + [b"accept-encoding", b"gzip,deflate"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-proto", b"https"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + ], + "http_version": "1.1", + "method": "GET", + "path": "/my/path", + "query_string": mock_http_api_event_v1["rawQueryString"].encode(), "raw_path": None, "root_path": "", "scheme": "https", @@ -353,7 +475,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_api_event, {}) + response = handler(mock_http_api_event_v1, {}) assert response == { "statusCode": 200, "isBase64Encoded": False,