diff --git a/mangum/adapter.py b/mangum/adapter.py index 31d2d1f..bb99cfb 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -44,6 +44,7 @@ def __init__( api_gateway_base_path: str = "/", custom_handlers: Optional[List[Type[LambdaHandler]]] = None, text_mime_types: Optional[List[str]] = None, + exclude_headers: Optional[List[str]] = None, ) -> None: if lifespan not in ("auto", "on", "off"): raise ConfigurationError( @@ -53,9 +54,11 @@ def __init__( self.app = app self.lifespan = lifespan self.custom_handlers = custom_handlers or [] + exclude_headers = exclude_headers or [] self.config = LambdaConfig( api_gateway_base_path=api_gateway_base_path or "/", text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES], + exclude_headers=[header.lower() for header in exclude_headers], ) def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler: diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py index 41378ed..875c4ee 100644 --- a/mangum/handlers/alb.py +++ b/mangum/handlers/alb.py @@ -5,6 +5,7 @@ from mangum.handlers.utils import ( get_server_and_port, handle_base64_response_body, + handle_exclude_headers, maybe_encode_body, ) from mangum.types import ( @@ -166,8 +167,10 @@ def __call__(self, response: Response) -> dict: # headers otherwise. multi_value_headers_enabled = "multiValueHeaders" in self.scope["aws.event"] if multi_value_headers_enabled: - out["multiValueHeaders"] = multi_value_headers + out["multiValueHeaders"] = handle_exclude_headers( + multi_value_headers, self.config + ) else: - out["headers"] = finalized_headers + out["headers"] = handle_exclude_headers(finalized_headers, self.config) return out diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index bd58a7d..d9b30c0 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -4,6 +4,7 @@ from mangum.handlers.utils import ( get_server_and_port, handle_base64_response_body, + handle_exclude_headers, handle_multi_value_headers, maybe_encode_body, strip_api_gateway_path, @@ -120,8 +121,10 @@ def __call__(self, response: Response) -> dict: return { "statusCode": response["status"], - "headers": finalized_headers, - "multiValueHeaders": multi_value_headers, + "headers": handle_exclude_headers(finalized_headers, self.config), + "multiValueHeaders": handle_exclude_headers( + multi_value_headers, self.config + ), "body": finalized_body, "isBase64Encoded": is_base64_encoded, } diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py index 6737967..89a3709 100644 --- a/mangum/handlers/lambda_at_edge.py +++ b/mangum/handlers/lambda_at_edge.py @@ -2,6 +2,7 @@ from mangum.handlers.utils import ( handle_base64_response_body, + handle_exclude_headers, handle_multi_value_headers, maybe_encode_body, ) @@ -88,7 +89,7 @@ def __call__(self, response: Response) -> dict: return { "status": response["status"], - "headers": finalized_headers, + "headers": handle_exclude_headers(finalized_headers, self.config), "body": response_body, "isBase64Encoded": is_base64_encoded, } diff --git a/mangum/handlers/utils.py b/mangum/handlers/utils.py index c1cce0b..7e3e7b3 100644 --- a/mangum/handlers/utils.py +++ b/mangum/handlers/utils.py @@ -1,8 +1,8 @@ import base64 -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from urllib.parse import unquote -from mangum.types import Headers +from mangum.types import Headers, LambdaConfig def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes: @@ -81,3 +81,15 @@ def handle_base64_response_body( is_base64_encoded = True return output_body, is_base64_encoded + + +def handle_exclude_headers( + headers: Dict[str, Any], config: LambdaConfig +) -> Dict[str, Any]: + finalized_headers = {} + for header_key, header_value in headers.items(): + if header_key in config["exclude_headers"]: + continue + finalized_headers[header_key] = header_value + + return finalized_headers diff --git a/mangum/types.py b/mangum/types.py index b50b0b2..0ff436c 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -117,6 +117,7 @@ class Response(TypedDict): class LambdaConfig(TypedDict): api_gateway_base_path: str text_mime_types: List[str] + exclude_headers: List[str] class LambdaHandler(Protocol): diff --git a/tests/handlers/test_alb.py b/tests/handlers/test_alb.py index 3804f9d..e75d2d9 100644 --- a/tests/handlers/test_alb.py +++ b/tests/handlers/test_alb.py @@ -372,3 +372,40 @@ async def app(scope, receive, send): "headers": {"content-type": content_type.decode()}, "body": utf_res_body, } + + +@pytest.mark.parametrize("multi_value_headers_enabled", (True, False)) +def test_aws_alb_exclude_headers(multi_value_headers_enabled) -> None: + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"x-custom-header", b"test"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + handler = Mangum(app, lifespan="off", exclude_headers=["x-custom-header"]) + event = get_mock_aws_alb_event( + "GET", "/test", {}, None, None, False, multi_value_headers_enabled + ) + response = handler(event, {}) + + expected_response = { + "statusCode": 200, + "isBase64Encoded": False, + "body": "Hello, world!", + } + if multi_value_headers_enabled: + expected_response["multiValueHeaders"] = { + "content-type": ["text/plain; charset=utf-8"], + } + else: + expected_response["headers"] = { + "content-type": "text/plain; charset=utf-8", + } + assert response == expected_response diff --git a/tests/handlers/test_api_gateway.py b/tests/handlers/test_api_gateway.py index 1231bb0..e2458c2 100644 --- a/tests/handlers/test_api_gateway.py +++ b/tests/handlers/test_api_gateway.py @@ -401,3 +401,31 @@ async def app(scope, receive, send): "multiValueHeaders": {}, "body": utf_res_body, } + + +def test_aws_api_gateway_exclude_headers(): + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"x-custom-header", b"test"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello world"}) + + event = get_mock_aws_api_gateway_event("GET", "/test", {}, None, False) + + handler = Mangum(app, lifespan="off", exclude_headers=["X-CUSTOM-HEADER"]) + + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": b"text/plain; charset=utf-8".decode()}, + "multiValueHeaders": {}, + "body": "Hello world", + } diff --git a/tests/handlers/test_lambda_at_edge.py b/tests/handlers/test_lambda_at_edge.py index ffeb9bc..563e144 100644 --- a/tests/handlers/test_lambda_at_edge.py +++ b/tests/handlers/test_lambda_at_edge.py @@ -342,3 +342,34 @@ async def app(scope, receive, send): }, "body": utf_res_body, } + + +def test_aws_lambda_at_edge_exclude_(): + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"x-custom-header", b"test"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello world"}) + + event = mock_lambda_at_edge_event("GET", "/test", {}, None, False) + + handler = Mangum(app, lifespan="off", exclude_headers=["x-custom-header"]) + + response = handler(event, {}) + assert response == { + "status": 200, + "isBase64Encoded": False, + "headers": { + "content-type": [ + {"key": "content-type", "value": b"text/plain; charset=utf-8".decode()} + ] + }, + "body": "Hello world", + } diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 6b50fd6..de36049 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -14,6 +14,7 @@ def test_default_settings(): assert handler.lifespan == "auto" assert handler.config["api_gateway_base_path"] == "/" assert sorted(handler.config["text_mime_types"]) == sorted(DEFAULT_TEXT_MIME_TYPES) + assert handler.config["exclude_headers"] == [] @pytest.mark.parametrize(