diff --git a/docs/asgi-frameworks.md b/docs/asgi-frameworks.md index 9fcd057e..907886f5 100644 --- a/docs/asgi-frameworks.md +++ b/docs/asgi-frameworks.md @@ -11,14 +11,15 @@ We can think about the ASGI framework support without referencing an existing im Let's invent an API for a non-existent microframework to demonstrate things further. This could represent *any* ASGI framework application: ```python +import mangum.adapter import framework -from mangum import Mangum +from mangum import Mangum, Request app = framework.applications.Application() @app.route("/") -def endpoint(request: framework.requests.Request) -> dict: +def endpoint(request: Request) -> dict: return {"hi": "there"} diff --git a/docs/http.md b/docs/http.md index 99bfda6a..24405366 100644 --- a/docs/http.md +++ b/docs/http.md @@ -1,7 +1,16 @@ # HTTP -Mangum provides support for both [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) and the newer [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) APIs in API Gateway. It also includes configurable binary response support. - +Mangum provides support for the following AWS HTTP Lambda Event Source: + + * [API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) + ([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-apigateway.html)) + * [HTTP Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) + ([Event Examples](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html)) + * [Application Load Balancer (ALB)](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/introduction.html) + ([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html)) + * [CloudFront Lambda@Edge](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/lambda-at-the-edge.html) + ([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html)) + ```python from fastapi import FastAPI from fastapi.middleware.gzip import GZipMiddleware diff --git a/mangum/__init__.py b/mangum/__init__.py index 18ba9b65..a81cdfb0 100644 --- a/mangum/__init__.py +++ b/mangum/__init__.py @@ -1 +1,4 @@ +from .types import Request, Response from .adapter import Mangum # noqa: F401 + +__all__ = ["Mangum", "Request", "Response"] diff --git a/mangum/adapter.py b/mangum/adapter.py index 0ff68a5b..1103ab81 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -1,36 +1,31 @@ -import base64 -import typing import logging -import urllib.parse - -from dataclasses import dataclass, InitVar from contextlib import ExitStack - -from mangum.types import ASGIApp, Scope -from mangum.protocols.lifespan import LifespanCycle -from mangum.protocols.http import HTTPCycle -from mangum.exceptions import ConfigurationError - -if typing.TYPE_CHECKING: # pragma: no cover +from typing import ( + Any, + ContextManager, + Dict, + TYPE_CHECKING, +) + +from .exceptions import ConfigurationError +from .handlers import AbstractHandler +from .protocols import HTTPCycle, LifespanCycle +from .types import ASGIApp + +if TYPE_CHECKING: # pragma: no cover from awslambdaric.lambda_context import LambdaContext DEFAULT_TEXT_MIME_TYPES = [ + "text/", "application/json", "application/javascript", "application/xml", "application/vnd.api+json", ] -LOG_LEVELS = { - "critical": logging.CRITICAL, - "error": logging.ERROR, - "warning": logging.WARNING, - "info": logging.INFO, - "debug": logging.DEBUG, -} +logger = logging.getLogger("mangum") -@dataclass class Mangum: """ Creates an adapter instance. @@ -41,153 +36,40 @@ class Mangum: and `off`. Default is `auto`. * **log_level** - A string to configure the log level. Choices are: `info`, `critical`, `error`, `warning`, and `debug`. Default is `info`. - * **api_gateway_base_path** - Base path to strip from URL when using a custom - domain name. * **text_mime_types** - A list of MIME types to include with the defaults that should not return a binary response in API Gateway. """ app: ASGIApp lifespan: str = "auto" - log_level: str = "info" - api_gateway_base_path: typing.Optional[str] = None - text_mime_types: InitVar[typing.Optional[typing.List[str]]] = None - def __post_init__(self, text_mime_types: typing.Optional[typing.List[str]]) -> None: + def __init__( + self, + app: ASGIApp, + lifespan: str = "auto", + **handler_kwargs: Dict[str, Any], + ): + self.app = app + self.lifespan = lifespan + self.handler_kwargs = handler_kwargs + if self.lifespan not in ("auto", "on", "off"): raise ConfigurationError( "Invalid argument supplied for `lifespan`. Choices are: auto|on|off" ) - if self.log_level not in ("critical", "error", "warning", "info", "debug"): - raise ConfigurationError( - "Invalid argument supplied for `log_level`. " - "Choices are: critical|error|warning|info|debug" - ) - - self.logger = logging.getLogger("mangum") - self.logger.setLevel(LOG_LEVELS[self.log_level]) - - should_prefix_base_path = ( - self.api_gateway_base_path - and not self.api_gateway_base_path.startswith("/") - ) - if should_prefix_base_path: - self.api_gateway_base_path = f"/{self.api_gateway_base_path}" - - if text_mime_types: - text_mime_types += DEFAULT_TEXT_MIME_TYPES - else: - text_mime_types = DEFAULT_TEXT_MIME_TYPES - self.text_mime_types = text_mime_types - def __call__(self, event: dict, context: "LambdaContext") -> dict: - self.logger.debug("Event received.") + logger.debug("Event received.") with ExitStack() as stack: if self.lifespan != "off": - lifespan_cycle: typing.ContextManager = LifespanCycle( - self.app, self.lifespan - ) + lifespan_cycle: ContextManager = LifespanCycle(self.app, self.lifespan) stack.enter_context(lifespan_cycle) - is_binary = event.get("isBase64Encoded", False) - initial_body = event.get("body") or b"" - if is_binary: - initial_body = base64.b64decode(initial_body) - elif not isinstance(initial_body, bytes): - initial_body = initial_body.encode() - - scope = self.create_scope(event, context) - http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types) - response = http_cycle(self.app, initial_body) - - return response - - def create_scope(self, event: dict, context: "LambdaContext") -> Scope: - """ - Creates a scope object according to ASGI specification from a Lambda Event. - - https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - - The event comes from various sources: AWS ALB, AWS API Gateway of different - versions and configurations(multivalue header, etc). - Thus, some heuristics is applied to guess an event type. - - """ - request_context = event["requestContext"] - - if event.get("multiValueHeaders"): - headers = { - k.lower(): ", ".join(v) if isinstance(v, list) else "" - for k, v in event.get("multiValueHeaders", {}).items() - } - elif event.get("headers"): - headers = {k.lower(): v for k, v in event.get("headers", {}).items()} - else: - headers = {} - - # API Gateway v2 - if event.get("version") == "2.0": - source_ip = request_context["http"]["sourceIp"] - path = request_context["http"]["path"] - http_method = request_context["http"]["method"] - query_string = event.get("rawQueryString", "").encode() - - if event.get("cookies"): - headers["cookie"] = "; ".join(event.get("cookies", [])) - - # API Gateway v1 / ELB - else: - if "elb" in request_context: - # NOTE: trust only the most right side value - source_ip = headers.get("x-forwarded-for", "").split(", ")[-1] - else: - source_ip = request_context.get("identity", {}).get("sourceIp") - - path = event["path"] - http_method = event["httpMethod"] - - if event.get("multiValueQueryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("multiValueQueryStringParameters", {}), doseq=True - ).encode() - elif event.get("queryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("queryStringParameters", {}) - ).encode() - else: - query_string = b"" - - server_name = headers.get("host", "mangum") - if ":" not in server_name: - server_port = headers.get("x-forwarded-port", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - client = (source_ip, 0) - - if not path: # pragma: no cover - path = "/" - elif self.api_gateway_base_path: - if path.startswith(self.api_gateway_base_path): - path = path[len(self.api_gateway_base_path) :] - - scope = { - "type": "http", - "http_version": "1.1", - "method": http_method, - "headers": [[k.encode(), v.encode()] for k, v in headers.items()], - "path": urllib.parse.unquote(path), - "raw_path": None, - "root_path": "", - "scheme": headers.get("x-forwarded-proto", "https"), - "query_string": query_string, - "server": server, - "client": client, - "asgi": {"version": "3.0"}, - "aws.event": event, - "aws.context": context, - } + handler = AbstractHandler.from_trigger( + event, context, **self.handler_kwargs + ) + http_cycle = HTTPCycle(handler.scope) + response = http_cycle(self.app, handler.body) - return scope + return handler.transform_response(response) diff --git a/mangum/handlers/__init__.py b/mangum/handlers/__init__.py new file mode 100644 index 00000000..9ab3303d --- /dev/null +++ b/mangum/handlers/__init__.py @@ -0,0 +1,13 @@ +from .abstract_handler import AbstractHandler +from .aws_alb import AwsAlb +from .aws_api_gateway import AwsApiGateway +from .aws_cf_lambda_at_edge import AwsCfLambdaAtEdge +from .aws_http_gateway import AwsHttpGateway + +__all__ = [ + "AbstractHandler", + "AwsAlb", + "AwsApiGateway", + "AwsCfLambdaAtEdge", + "AwsHttpGateway", +] diff --git a/mangum/handlers/abstract_handler.py b/mangum/handlers/abstract_handler.py new file mode 100644 index 00000000..9e613554 --- /dev/null +++ b/mangum/handlers/abstract_handler.py @@ -0,0 +1,134 @@ +import base64 +from abc import ABCMeta, abstractmethod +from typing import Dict, Any, TYPE_CHECKING, Tuple, List + +from .. import Response, Request + +if TYPE_CHECKING: # pragma: no cover + from awslambdaric.lambda_context import LambdaContext + + +class AbstractHandler(metaclass=ABCMeta): + def __init__( + self, + trigger_event: Dict[str, Any], + trigger_context: "LambdaContext", + **kwargs: Dict[str, Any] + ): + self.trigger_event = trigger_event + self.trigger_context = trigger_context + + @property + @abstractmethod + def scope(self) -> Request: + """ + Parse an ASGI scope from the request event + """ + + @property + @abstractmethod + def body(self) -> bytes: + """ + Get the raw body from the request event + """ + + @abstractmethod + def transform_response(self, response: Response) -> Dict[str, Any]: + """ + After running our application, transform the response to the correct format for + this handler + """ + + @staticmethod + def from_trigger( + trigger_event: Dict[str, Any], + trigger_context: "LambdaContext", + **kwargs: Dict[str, Any] + ) -> "AbstractHandler": + """ + A factory method that determines which handler to use. All this code should + probably stay in one place to make sure we are able to uniquely find each + handler correctly. + """ + + # These should be ordered from most specific to least for best accuracy + if ( + "requestContext" in trigger_event + and "elb" in trigger_event["requestContext"] + ): + from . import AwsAlb + + return AwsAlb(trigger_event, trigger_context, **kwargs) + + if ( + "Records" in trigger_event + and len(trigger_event["Records"]) > 0 + and "cf" in trigger_event["Records"][0] + ): + from . import AwsCfLambdaAtEdge + + return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs) + + if "version" in trigger_event and "requestContext" in trigger_event: + from . import AwsHttpGateway + + return AwsHttpGateway(trigger_event, trigger_context, **kwargs) + + if "resource" in trigger_event: + from . import AwsApiGateway + + return AwsApiGateway( + trigger_event, trigger_context, **kwargs # type: ignore + ) + + raise TypeError("Unable to determine handler from trigger event") + + @staticmethod + def _handle_multi_value_headers( + response_headers: List[List[bytes]], + ) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + headers: Dict[str, str] = {} + multi_value_headers: Dict[str, List[str]] = {} + for key, value in response_headers: + lower_key = key.decode().lower() + if lower_key in multi_value_headers: + multi_value_headers[lower_key].append(value.decode()) + elif lower_key in headers: + # Move existing to multi_value_headers and append current + multi_value_headers[lower_key] = [ + headers[lower_key], + value.decode(), + ] + del headers[lower_key] + else: + headers[lower_key] = value.decode() + return headers, multi_value_headers + + @staticmethod + def _handle_base64_response_body( + body: bytes, headers: Dict[str, str] + ) -> Tuple[str, bool]: + """ + To ease debugging for our users, try and return strings where we can, + otherwise to ensure maximum compatibility with binary data, base64 encode it. + """ + is_base64_encoded = False + output_body = "" + if body != b"": + from ..adapter import DEFAULT_TEXT_MIME_TYPES + + for text_mime_type in DEFAULT_TEXT_MIME_TYPES: + if text_mime_type in headers.get("content-type", ""): + try: + output_body = body.decode() + except UnicodeDecodeError: + # Can't decode it, base64 it and be done + output_body = base64.b64encode(body).decode() + is_base64_encoded = True + break + else: + # Not text, base64 encode + output_body = base64.b64encode(body).decode() + is_base64_encoded = True + + return output_body, is_base64_encoded diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py new file mode 100644 index 00000000..69db2475 --- /dev/null +++ b/mangum/handlers/aws_alb.py @@ -0,0 +1,80 @@ +import base64 +import urllib.parse +from typing import Dict, Any + +from .abstract_handler import AbstractHandler +from .. import Response, Request + + +class AwsAlb(AbstractHandler): + """ + Handles AWS Elastic Load Balancer, really Application Load Balancer events + transforming them into ASGI Scope and handling responses + + See: https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html + """ + + TYPE = "AWS_ALB" + + @property + def scope(self) -> Request: + event = self.trigger_event + + headers = {} + if event.get("headers"): + headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + + source_ip = headers.get("x-forwarded-for", "") + path = event["path"] + http_method = event["httpMethod"] + query_string = urllib.parse.urlencode( + event.get("queryStringParameters", {}), doseq=True + ).encode() + + server_name = headers.get("host", "mangum") + if ":" not in server_name: + server_port = headers.get("x-forwarded-port", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + client = (source_ip, 0) + + if not path: + path = "/" + + return Request( + method=http_method, + headers=[[k.encode(), v.encode()] for k, v in headers.items()], + path=urllib.parse.unquote(path), + scheme=headers.get("x-forwarded-proto", "https"), + query_string=query_string, + server=server, + client=client, + trigger_event=self.trigger_event, + trigger_context=self.trigger_context, + event_type=self.TYPE, + ) + + @property + def body(self) -> bytes: + body = self.trigger_event.get("body", b"") + if self.trigger_event.get("isBase64Encoded", False): + body = base64.b64decode(body) + return body + + def transform_response(self, response: Response) -> Dict[str, Any]: + headers, multi_value_headers = self._handle_multi_value_headers( + response.headers + ) + + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } diff --git a/mangum/handlers/aws_api_gateway.py b/mangum/handlers/aws_api_gateway.py new file mode 100644 index 00000000..c1aa5de4 --- /dev/null +++ b/mangum/handlers/aws_api_gateway.py @@ -0,0 +1,117 @@ +import base64 +import urllib.parse +from typing import Dict, Any, TYPE_CHECKING + +from .abstract_handler import AbstractHandler +from .. import Response, Request + +if TYPE_CHECKING: # pragma: no cover + from awslambdaric.lambda_context import LambdaContext + + +class AwsApiGateway(AbstractHandler): + """ + Handles AWS API Gateway events, transforming them into ASGI Scope and handling + responses + + See: https://docs.aws.amazon.com/lambda/latest/dg/services-apigateway.html + """ + + TYPE = "AWS_API_GATEWAY" + + def __init__( + self, + trigger_event: Dict[str, Any], + trigger_context: "LambdaContext", + base_path: str = "/", + **kwargs: Dict[str, Any], # type: ignore + ): + super().__init__(trigger_event, trigger_context, **kwargs) + self.base_path = base_path + + @property + def scope(self) -> Request: + event = self.trigger_event + + # multiValue versions of headers take precedence over their plain versions + # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format + if event.get("multiValueHeaders"): + headers = { + k.lower(): ", ".join(v) if isinstance(v, list) else "" + for k, v in event.get("multiValueHeaders", {}).items() + } + elif event.get("headers"): + headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + else: + headers = {} + + request_context = event["requestContext"] + + source_ip = request_context.get("identity", {}).get("sourceIp") + + path = event["path"] + http_method = event["httpMethod"] + + if event.get("multiValueQueryStringParameters"): + query_string = urllib.parse.urlencode( + event.get("multiValueQueryStringParameters", {}), doseq=True + ).encode() + elif event.get("queryStringParameters"): + query_string = urllib.parse.urlencode( + event.get("queryStringParameters", {}) + ).encode() + else: + query_string = b"" + + server_name = headers.get("host", "mangum") + if ":" not in server_name: + server_port = headers.get("x-forwarded-port", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + client = (source_ip, 0) + + if not path: + path = "/" + elif self.base_path and self.base_path != "/": + if not self.base_path.startswith("/"): + self.base_path = f"/{self.base_path}" + if path.startswith(self.base_path): + path = path[len(self.base_path) :] + + return Request( + method=http_method, + headers=[[k.encode(), v.encode()] for k, v in headers.items()], + path=urllib.parse.unquote(path), + scheme=headers.get("x-forwarded-proto", "https"), + query_string=query_string, + server=server, + client=client, + trigger_event=self.trigger_event, + trigger_context=self.trigger_context, + event_type=self.TYPE, + ) + + @property + def body(self) -> bytes: + body = self.trigger_event.get("body", b"") + if self.trigger_event.get("isBase64Encoded", False): + body = base64.b64decode(body) + return body + + def transform_response(self, response: Response) -> Dict[str, Any]: + headers, multi_value_headers = self._handle_multi_value_headers( + response.headers + ) + + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } diff --git a/mangum/handlers/aws_cf_lambda_at_edge.py b/mangum/handlers/aws_cf_lambda_at_edge.py new file mode 100644 index 00000000..6fb77e93 --- /dev/null +++ b/mangum/handlers/aws_cf_lambda_at_edge.py @@ -0,0 +1,79 @@ +import base64 +from typing import Dict, Any, List + +from .abstract_handler import AbstractHandler +from .. import Response, Request + + +class AwsCfLambdaAtEdge(AbstractHandler): + """ + Handles AWS Elastic Load Balancer, really Application Load Balancer events + transforming them into ASGI Scope and handling responses + + See: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/lambda-event-structure.html # noqa: E501 + """ + + TYPE = "AWS_CF_LAMBDA_AT_EDGE" + + @property + def scope(self) -> Request: + event = self.trigger_event + + cf_request = event["Records"][0]["cf"]["request"] + + scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}]) + scheme = scheme_header[0].get("value", "https") + + host_header = cf_request["headers"].get("host", [{}]) + server_name = host_header[0].get("value", "mangum") + if ":" not in server_name: + forwarded_port_header = cf_request["headers"].get("x-forwarded-port", [{}]) + server_port = forwarded_port_header[0].get("value", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + + source_ip = cf_request["clientIp"] + client = (source_ip, 0) + + return Request( + method=cf_request["method"], + headers=[ + [k.encode(), v[0]["value"].encode()] + for k, v in cf_request["headers"].items() + ], + path=cf_request["uri"], + scheme=scheme, + query_string=cf_request["querystring"].encode(), + server=server, + client=client, + trigger_event=self.trigger_event, + trigger_context=self.trigger_context, + event_type=self.TYPE, + ) + + @property + def body(self) -> bytes: + request = self.trigger_event["Records"][0]["cf"]["request"] + body = request.get("body", {}).get("data", None) + if request.get("body", {}).get("encoding", "") == "base64": + body = base64.b64decode(body) + return body + + def transform_response(self, response: Response) -> Dict[str, Any]: + headers_dict, _ = self._handle_multi_value_headers(response.headers) + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers_dict + ) + + # Expand headers to weird list of Dict[str, List[Dict[str, str]]] + headers_expanded: Dict[str, List[Dict[str, str]]] = { + key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}] + for key, val in response.headers + } + return { + "status": response.status, + "headers": headers_expanded, + "body": body, + "isBase64Encoded": is_base64_encoded, + } diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py new file mode 100644 index 00000000..6227c9a2 --- /dev/null +++ b/mangum/handlers/aws_http_gateway.py @@ -0,0 +1,152 @@ +import base64 +import urllib.parse +from typing import Dict, Any + +from .abstract_handler import AbstractHandler +from .. import Response, Request + + +class AwsHttpGateway(AbstractHandler): + """ + Handles AWS HTTP Gateway events (v1.0 and v2.0), transforming them into ASGI Scope + and handling responses + + See: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html # noqa: E501 + """ + + TYPE = "AWS_HTTP_GATEWAY" + + @property + def event_version(self) -> str: + return self.trigger_event.get("version", "") + + @property + def scope(self) -> Request: + event = self.trigger_event + + headers = {} + if event.get("headers"): + headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + + request_context = event["requestContext"] + + # API Gateway v2 + if self.event_version == "2.0": + source_ip = request_context["http"]["sourceIp"] + path = request_context["http"]["path"] + http_method = request_context["http"]["method"] + query_string = event.get("rawQueryString", "").encode() + + if event.get("cookies"): + headers["cookie"] = "; ".join(event.get("cookies", [])) + + # API Gateway v1 + elif self.event_version == "1.0": + # v1.0 of the HTTP Gateway supports multiValueHeaders + if event.get("multiValueHeaders"): + headers.update( + { + k.lower(): ", ".join(v) if isinstance(v, list) else "" + for k, v in event.get("multiValueHeaders", {}).items() + } + ) + + source_ip = request_context.get("identity", {}).get("sourceIp") + + path = event["path"] + http_method = event["httpMethod"] + + # AWS Blog Post on this: + # https://aws.amazon.com/blogs/compute/support-for-multi-value-parameters-in-amazon-api-gateway/ # noqa: E501 + # A multi value param will be in multi value _and_ regular + # queryStringParameters. Multi value takes precedence. + if event.get("multiValueQueryStringParameters", False): + query_string = urllib.parse.urlencode( + event.get("multiValueQueryStringParameters", {}), doseq=True + ).encode() + elif event.get("queryStringParameters", False): + query_string = urllib.parse.urlencode( + event.get("queryStringParameters", {}) + ).encode() + else: + query_string = b"" + else: + raise RuntimeError( + "Unsupported version of HTTP Gateway Spec, only v1.0 and v2.0 are " + "supported." + ) + + server_name = headers.get("host", "mangum") + if ":" not in server_name: + server_port = headers.get("x-forwarded-port", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + client = (source_ip, 0) + + if not path: + path = "/" + + return Request( + method=http_method, + headers=[[k.encode(), v.encode()] for k, v in headers.items()], + path=urllib.parse.unquote(path), + scheme=headers.get("x-forwarded-proto", "https"), + query_string=query_string, + server=server, + client=client, + trigger_event=self.trigger_event, + trigger_context=self.trigger_context, + event_type=self.TYPE, + ) + + @property + def body(self) -> bytes: + body = self.trigger_event.get("body", b"") + if self.trigger_event.get("isBase64Encoded", False): + body = base64.b64decode(body) + return body + + def transform_response(self, response: Response) -> Dict[str, Any]: + """ + This handles some unnecessary magic from AWS + + > API Gateway can infer the response format for you + Boooooo + + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response + """ + headers, multi_value_headers = self._handle_multi_value_headers( + response.headers + ) + + if self.event_version == "1.0": + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } + elif self.event_version == "2.0": + # The API Gateway will infer stuff for us, but we'll just do that inference + # here and keep the output consistent + if "content-type" not in headers and response.body is not None: + headers["content-type"] = "application/json" + + body, is_base64_encoded = self._handle_base64_response_body( + response.body, headers + ) + return { + "statusCode": response.status, + "headers": headers, + "multiValueHeaders": multi_value_headers, + "body": body, + "isBase64Encoded": is_base64_encoded, + } + raise RuntimeError( # pragma: no cover + "Misconfigured event unable to return value, unsupported version." + ) diff --git a/mangum/protocols/__init__.py b/mangum/protocols/__init__.py index e69de29b..7c918bf8 100644 --- a/mangum/protocols/__init__.py +++ b/mangum/protocols/__init__.py @@ -0,0 +1,8 @@ +from .http import HTTPCycle +from .lifespan import LifespanCycleState, LifespanCycle + +__all__ = [ + "HTTPCycle", + "LifespanCycleState", + "LifespanCycle", +] diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index 1cb5d413..8b36e9bc 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -1,33 +1,13 @@ -import base64 import enum import asyncio -import typing -import cgi +from typing import Optional import logging from io import BytesIO -from dataclasses import dataclass, field +from dataclasses import dataclass -from mangum.types import ASGIApp, Message, Scope -from mangum.exceptions import UnexpectedMessage - - -def all_casings(input_string: str) -> typing.Generator: - """ - Permute all casings of a given string. - A pretty algoritm, via @Amber - http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python - """ - if not input_string: - yield "" - else: - first = input_string[:1] - if first.lower() == first.upper(): - for sub_casing in all_casings(input_string[1:]): - yield first + sub_casing - else: - for sub_casing in all_casings(input_string[1:]): - yield first.lower() + sub_casing - yield first.upper() + sub_casing +from .. import Response, Request +from ..types import ASGIApp, Message +from ..exceptions import UnexpectedMessage class HTTPCycleState(enum.Enum): @@ -57,8 +37,6 @@ class HTTPCycle: * **scope** - A dictionary containing the connection scope used to run the ASGI application instance. - * **text_mime_types** - A list of mime types of MIME types that should not return - a binary response in API Gateway. * **state** - An enumerated `HTTPCycleState` type that indicates the state of the ASGI connection. * **app_queue** - An asyncio queue (FIFO) containing messages to be received by the @@ -66,19 +44,17 @@ class HTTPCycle: * **response** - A dictionary containing the response data to return in AWS Lambda. """ - scope: Scope - text_mime_types: typing.List[str] + scope: Request state: HTTPCycleState = HTTPCycleState.REQUEST - response: dict = field(default_factory=dict) + response: Optional[Response] = None def __post_init__(self) -> None: self.logger: logging.Logger = logging.getLogger("mangum.http") self.loop = asyncio.get_event_loop() self.app_queue: asyncio.Queue = asyncio.Queue() self.body: BytesIO = BytesIO() - self.response["isBase64Encoded"] = False - def __call__(self, app: ASGIApp, initial_body: bytes) -> dict: + def __call__(self, app: ASGIApp, initial_body: bytes) -> Response: self.logger.debug("HTTP cycle starting.") self.app_queue.put_nowait( {"type": "http.request", "body": initial_body, "more_body": False} @@ -87,6 +63,14 @@ def __call__(self, app: ASGIApp, initial_body: bytes) -> dict: asgi_task = self.loop.create_task(asgi_instance) self.loop.run_until_complete(asgi_task) + if self.response is None: + # Something really bad happened and we puked before we could get a + # response out + self.response = Response( + status=500, + body=b"Internal Server Error", + headers=[[b"content-type", b"text/plain; charset=utf-8"]], + ) return self.response async def run(self, app: ASGIApp) -> None: @@ -94,7 +78,7 @@ async def run(self, app: ASGIApp) -> None: Calls the application with the `http` connection scope. """ try: - await app(self.scope, self.receive, self.send) + await app(self.scope.as_dict(), self.receive, self.send) except BaseException as exc: self.logger.error("Exception in 'http' protocol.", exc_info=exc) if self.state is HTTPCycleState.REQUEST: @@ -102,22 +86,24 @@ async def run(self, app: ASGIApp) -> None: { "type": "http.response.start", "status": 500, - "headers": [(b"content-type", b"text/plain; charset=utf-8")], + "headers": [[b"content-type", b"text/plain; charset=utf-8"]], } ) await self.send( {"type": "http.response.body", "body": b"Internal Server Error"} ) elif self.state is not HTTPCycleState.COMPLETE: - self.response["statusCode"] = 500 - self.response["body"] = "Internal Server Error" - self.response["headers"] = {"content-type": "text/plain; charset=utf-8"} + self.response = Response( + status=500, + body=b"Internal Server Error", + headers=[[b"content-type", b"text/plain; charset=utf-8"]], + ) async def receive(self) -> Message: """ Awaited by the application to receive ASGI `http` events. """ - return await self.app_queue.get() + return await self.app_queue.get() # pragma: no cover async def send(self, message: Message) -> None: """ @@ -132,52 +118,13 @@ async def send(self, message: Message) -> None: self.state is HTTPCycleState.REQUEST and message_type == "http.response.start" ): - self.response["statusCode"] = message["status"] - headers: typing.Dict[str, str] = {} - multi_value_headers: typing.Dict[str, typing.List[str]] = {} - cookies: typing.List[str] = [] - event = self.scope["aws.event"] - # ELB - if "elb" in event["requestContext"]: - for key, value in message.get("headers", []): - lower_key = key.decode().lower() - if lower_key in multi_value_headers: - multi_value_headers[lower_key].append(value.decode()) - else: - multi_value_headers[lower_key] = [value.decode()] - if "multiValueHeaders" not in event: - # If there are multiple occurrences of headers, create case-mutated - # variations: https://github.com/logandk/serverless-wsgi/issues/11 - for key, values in multi_value_headers.items(): - if len(values) > 1: - for value, cased_key in zip(values, all_casings(key)): - headers[cased_key] = value - elif len(values) == 1: - headers[key] = values[0] - multi_value_headers = {} - # API Gateway - else: - for key, value in message.get("headers", []): - lower_key = key.decode().lower() - if event.get("version") == "2.0" and lower_key == "set-cookie": - cookies.append(value.decode()) - elif lower_key in multi_value_headers: - multi_value_headers[lower_key].append(value.decode()) - elif lower_key in headers: - multi_value_headers[lower_key] = [ - headers.pop(lower_key), - value.decode(), - ] - else: - headers[lower_key] = value.decode() - - self.response["headers"] = headers - if multi_value_headers: - self.response["multiValueHeaders"] = multi_value_headers - if cookies: - self.response["cookies"] = cookies + if self.response is None: + self.response = Response( + status=message["status"], + headers=message.get("headers", []), + body=b"", + ) self.state = HTTPCycleState.RESPONSE - elif ( self.state is HTTPCycleState.RESPONSE and message_type == "http.response.body" @@ -188,23 +135,11 @@ async def send(self, message: Message) -> None: # The body must be completely read before returning the response. self.body.write(body) - if not more_body: + if not more_body and self.response is not None: body = self.body.getvalue() self.body.close() + self.response.body = body - # Check if a binary response should be returned based on the mime type - # or content encoding. - mimetype, _ = cgi.parse_header( - self.response["headers"].get("content-type", "text/plain") - ) - if ( - mimetype not in self.text_mime_types - and not mimetype.startswith("text/") - ) or self.response["headers"].get("content-encoding") in ["gzip", "br"]: - body = base64.b64encode(body) - self.response["isBase64Encoded"] = True - - self.response["body"] = body.decode() self.state = HTTPCycleState.COMPLETE await self.app_queue.put({"type": "http.disconnect"}) diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index 6322ab24..790277bb 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -5,8 +5,8 @@ import enum from dataclasses import dataclass -from mangum.types import ASGIApp, Message -from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage +from ..types import ASGIApp, Message +from ..exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage class LifespanCycleState(enum.Enum): diff --git a/mangum/types.py b/mangum/types.py index 006c852b..03057b49 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -1,12 +1,74 @@ import typing +from dataclasses import dataclass, field +from typing import List, Tuple, Dict, Any, Union, Optional, TYPE_CHECKING + from typing_extensions import Protocol Message = typing.MutableMapping[str, typing.Any] -Scope = typing.MutableMapping[str, typing.Any] +ScopeDict = typing.MutableMapping[str, typing.Any] Receive = typing.Callable[[], typing.Awaitable[Message]] Send = typing.Callable[[Message], typing.Awaitable[None]] +if TYPE_CHECKING: # pragma: no cover + from awslambdaric.lambda_context import LambdaContext + class ASGIApp(Protocol): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def __call__(self, scope: ScopeDict, receive: Receive, send: Send) -> None: ... # pragma: no cover + + +@dataclass +class Request: + """ + A holder for an ASGI scope. Contains additional meta from the event that triggered + the + + https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope + """ + + method: str + headers: List[List[bytes]] + path: str + scheme: str + query_string: bytes + server: Tuple[str, int] + client: Tuple[str, int] + + # Invocation event + trigger_event: Dict[str, Any] + trigger_context: Union["LambdaContext", Dict[str, Any]] + event_type: str + + type: str = "http" + http_version: str = "1.1" + raw_path: Optional[str] = None + root_path: str = "" + asgi: Dict[str, str] = field(default_factory=lambda: {"version": "3.0"}) + + def as_dict(self) -> ScopeDict: + return { + "type": self.type, + "http_version": self.http_version, + "method": self.method, + "headers": self.headers, + "path": self.path, + "raw_path": self.raw_path, + "root_path": self.root_path, + "scheme": self.scheme, + "query_string": self.query_string, + "server": self.server, + "client": self.client, + "asgi": self.asgi, + # Meta data to pass along to the application in case they need it + "aws.event": self.trigger_event, + "aws.context": self.trigger_context, + "aws.eventType": self.event_type, + } + + +@dataclass +class Response: + status: int + headers: List[List[bytes]] # ex: [[b'content-type', b'text/plain; charset=utf-8']] + body: bytes diff --git a/setup.cfg b/setup.cfg index ea102e04..78c5ea9c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,4 @@ [flake8] - max-line-length = 88 ignore = E203, W503 per-file-ignores = diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index eb88ed9a..29b32d30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ @pytest.fixture -def mock_http_event(request): +def mock_aws_api_gateway_event(request): method = request.param[0] body = request.param[1] multi_value_query_parameters = request.param[2] @@ -119,62 +119,72 @@ def mock_http_api_event(request): @pytest.fixture -def mock_http_elb_singlevalue_event(request): +def mock_lambda_at_edge_event(request): method = request.param[0] - body = request.param[1] - multi_value_query_parameters = request.param[2] - event = { - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0" - } - }, - "httpMethod": method, - "path": "/my/path", - "queryStringParameters": { - k: v[-1] for k, v in multi_value_query_parameters.items() - } - if multi_value_query_parameters - else None, - "headers": { - "accept-encoding": "gzip, deflate", - "cookie": "cookie1; cookie2", - "host": "test.execute-api.us-west-2.amazonaws.com", - "x-forwarded-for": "192.168.100.3, 192.168.100.2, 192.168.100.1", - "x-forwarded-port": "443", - "x-forwarded-proto": "https", - }, - "body": body, - "isBase64Encoded": False, - } - - return event + path = request.param[1] + query_string = request.param[2] + body = request.param[3] + headers_raw = { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-for": "192.168.100.1", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + } + headers = {} + for key, value in headers_raw.items(): + headers[key.lower()] = [{"key": key, "value": value}] -@pytest.fixture -def mock_http_elb_multivalue_event(request): - method = request.param[0] - body = request.param[1] - multi_value_query_parameters = request.param[2] event = { - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0" + "Records": [ + { + "cf": { + "config": { + "distributionDomainName": "mock-distribution.local.localhost", + "distributionId": "ABC123DEF456G", + "eventType": "origin-request", + "requestId": "lBEBo2N0JKYUP2JXwn_4am2xAXB2GzcL2FlwXI8G59PA8wghF2ImFQ==", + }, + "request": { + "clientIp": "192.168.100.1", + "headers": headers, + "method": method, + "origin": { + "custom": { + "customHeaders": { + "x-lae-env-custom-var": [ + { + "key": "x-lae-env-custom-var", + "value": "environment variable", + } + ], + }, + "domainName": "www.example.com", + "keepaliveTimeout": 5, + "path": "", + "port": 80, + "protocol": "http", + "readTimeout": 30, + "sslProtocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + } + }, + "querystring": query_string, + "uri": path, + }, + } } - }, - "httpMethod": method, - "path": "/my/path", - "multiValueQueryStringParameters": multi_value_query_parameters or None, - "multiValueHeaders": { - "accept-encoding": ["gzip, deflate"], - "cookie": ["cookie1; cookie2"], - "host": ["test.execute-api.us-west-2.amazonaws.com"], - "x-forwarded-for": ["192.168.100.3, 192.168.100.2, 192.168.100.1"], - "x-forwarded-port": ["443"], - "x-forwarded-proto": ["https"], - }, - "body": body, - "isBase64Encoded": False, + ] } - return event + if body is not None: + event["Records"][0]["cf"]["request"]["body"] = { + "inputTruncated": False, + "action": "read-only", + "encoding": "text", + "data": body, + } + + return dict( + method=method, path=path, query_string=query_string, body=body, event=event + ) diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/handlers/test_abstract_handler.py b/tests/handlers/test_abstract_handler.py new file mode 100644 index 00000000..fd2d4310 --- /dev/null +++ b/tests/handlers/test_abstract_handler.py @@ -0,0 +1,13 @@ +import pytest + +from mangum.handlers import AbstractHandler + + +def test_abstract_handler_unkown_event(): + """ + Test an unknown event, ensure it fails in a consistent way + """ + example_event = {"hello": "world", "foo": "bar"} + example_context = {} + with pytest.raises(TypeError): + AbstractHandler.from_trigger(example_event, example_context) diff --git a/tests/handlers/test_aws_alb.py b/tests/handlers/test_aws_alb.py new file mode 100644 index 00000000..36640f59 --- /dev/null +++ b/tests/handlers/test_aws_alb.py @@ -0,0 +1,285 @@ +import pytest + +from mangum import Mangum +from mangum.handlers import AwsAlb + + +def get_mock_aws_alb_event( + method, path, multi_value_query_parameters, body, body_base64_encoded +): + return { + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501 + } + }, + "httpMethod": method, + "path": path, + "queryStringParameters": multi_value_query_parameters + if multi_value_query_parameters + else {}, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9," + "image/webp,image/apng,*/*;q=0.8", + "accept-encoding": "gzip", + "accept-language": "en-US,en;q=0.9", + "connection": "keep-alive", + "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/71.0.3578.98 Safari/537.36", + "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", + "x-forwarded-for": "72.12.164.125", + "x-forwarded-port": "80", + "x-forwarded-proto": "http", + "x-imforwards": "20", + }, + "body": body, + "isBase64Encoded": body_base64_encoded, + } + + +def test_aws_alb_basic(): + """ + Test the event from the AWS docs + """ + example_event = { + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" # noqa: E501 + } + }, + "httpMethod": "GET", + "path": "/lambda", + "queryStringParameters": {"query": "1234ABCD"}, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9," + "image/webp,image/apng,*/*;q=0.8", + "accept-encoding": "gzip", + "accept-language": "en-US,en;q=0.9", + "connection": "keep-alive", + "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", # noqa: E501 + "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", + "x-forwarded-for": "72.12.164.125", + "x-forwarded-port": "80", + "x-forwarded-proto": "http", + "x-imforwards": "20", + }, + "body": "", + "isBase64Encoded": False, + } + + example_context = {} + handler = AwsAlb(example_event, example_context) + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": example_event, + "aws.eventType": "AWS_ALB", + "client": ("72.12.164.125", 0), + "headers": [ + [ + b"accept", + b"text/html,application/xhtml+xml,application/xml;q=0.9,image/" + b"webp,image/apng,*/*;q=0.8", + ], + [b"accept-encoding", b"gzip"], + [b"accept-language", b"en-US,en;q=0.9"], + [b"connection", b"keep-alive"], + [b"host", b"lambda-alb-123578498.us-east-2.elb.amazonaws.com"], + [b"upgrade-insecure-requests", b"1"], + [ + b"user-agent", + b"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" + b" (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", + ], + [b"x-amzn-trace-id", b"Root=1-5c536348-3d683b8b04734faae651f476"], + [b"x-forwarded-for", b"72.12.164.125"], + [b"x-forwarded-port", b"80"], + [b"x-forwarded-proto", b"http"], + [b"x-imforwards", b"20"], + ], + "http_version": "1.1", + "method": "GET", + "path": "/lambda", + "query_string": b"query=1234ABCD", + "raw_path": None, + "root_path": "", + "scheme": "http", + "server": ("lambda-alb-123578498.us-east-2.elb.amazonaws.com", 80), + "type": "http", + } + + +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body,body_base64_encoded," + "query_string,scope_body", + [ + ("GET", "/hello/world", None, None, False, b"", None), + ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "GET", + "/my/resource", + {"name": ["me", "you"]}, + None, + False, + b"name=me&name=you", + None, + ), + ( + "GET", + "", + {"name": ["me", "you"], "pet": ["dog"]}, + None, + False, + b"name=me&name=you&pet=dog", + None, + ), + # A 1x1 red px gif + ( + "POST", + "/img", + None, + b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + b"", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + ), + ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), + ], +) +def test_aws_alb_scope_real( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = get_mock_aws_alb_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + example_context = {} + handler = AwsAlb(event, example_context) + + scope_path = path + if scope_path == "": + scope_path = "/" + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": event, + "aws.eventType": "AWS_ALB", + "client": ("72.12.164.125", 0), + "headers": [ + [ + b"accept", + b"text/html,application/xhtml+xml,application/xml;q=0.9,image/" + b"webp,image/apng,*/*;q=0.8", + ], + [b"accept-encoding", b"gzip"], + [b"accept-language", b"en-US,en;q=0.9"], + [b"connection", b"keep-alive"], + [b"host", b"lambda-alb-123578498.us-east-2.elb.amazonaws.com"], + [b"upgrade-insecure-requests", b"1"], + [ + b"user-agent", + b"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" + b" (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", + ], + [b"x-amzn-trace-id", b"Root=1-5c536348-3d683b8b04734faae651f476"], + [b"x-forwarded-for", b"72.12.164.125"], + [b"x-forwarded-port", b"80"], + [b"x-forwarded-proto", b"http"], + [b"x-imforwards", b"20"], + ], + "http_version": "1.1", + "method": method, + "path": scope_path, + "query_string": query_string, + "raw_path": None, + "root_path": "", + "scheme": "http", + "server": ("lambda-alb-123578498.us-east-2.elb.amazonaws.com", 80), + "type": "http", + } + + assert handler.body == scope_body + + +def test_aws_alb_set_cookies() -> None: + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain; charset=utf-8"], + [b"set-cookie", b"cookie1=cookie1; Secure"], + [b"set-cookie", b"cookie2=cookie2; Secure"], + ], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + handler = Mangum(app, lifespan="off") + event = get_mock_aws_alb_event("GET", "/test", {}, None, False) + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": { + "set-cookie": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], + }, + "body": "Hello, world!", + } + + +@pytest.mark.parametrize( + "method,content_type,raw_res_body,res_body,res_base64_encoded", + [ + ("GET", b"text/plain; charset=utf-8", b"Hello world", "Hello world", False), + # A 1x1 red px gif + ( + "POST", + b"image/gif", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + "R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + ), + ], +) +def test_aws_alb_response( + method, content_type, raw_res_body, res_body, res_base64_encoded +): + async def app(scope, receive, send): + assert scope["aws.eventType"] == "AWS_ALB" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_alb_event(method, "/test", {}, None, False) + + handler = Mangum(app, lifespan="off") + + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": res_base64_encoded, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": res_body, + } diff --git a/tests/handlers/test_aws_api_gateway.py b/tests/handlers/test_aws_api_gateway.py new file mode 100644 index 00000000..f7b87a37 --- /dev/null +++ b/tests/handlers/test_aws_api_gateway.py @@ -0,0 +1,333 @@ +import pytest + +import urllib.parse + +from mangum import Mangum +from mangum.handlers import AwsApiGateway + + +def get_mock_aws_api_gateway_event( + method, path, multi_value_query_parameters, body, body_base64_encoded +): + return { + "path": path, + "body": body, + "isBase64Encoded": body_base64_encoded, + "headers": { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9," + "image/webp,*/*;q=0.8", + "Accept-Encoding": "gzip, deflate, lzma, sdch, br", + "Accept-Language": "en-US,en;q=0.8", + "CloudFront-Forwarded-Proto": "https", + "CloudFront-Is-Desktop-Viewer": "true", + "CloudFront-Is-Mobile-Viewer": "false", + "CloudFront-Is-SmartTV-Viewer": "false", + "CloudFront-Is-Tablet-Viewer": "false", + "CloudFront-Viewer-Country": "US", + "Cookie": "cookie1; cookie2", + "Host": "test.execute-api.us-west-2.amazonaws.com", + "Upgrade-Insecure-Requests": "1", + "X-Forwarded-For": "192.168.100.1, 192.168.1.1", + "X-Forwarded-Port": "443", + "X-Forwarded-Proto": "https", + }, + "pathParameters": {"proxy": "hello"}, + "requestContext": { + "accountId": "123456789012", + "resourceId": "us4z18", + "stage": "Prod", + "requestId": "41b45ea3-70b5-11e6-b7bd-69b5aaebc7d9", + "identity": { + "cognitoIdentityPoolId": "", + "accountId": "", + "cognitoIdentityId": "", + "caller": "", + "apiKey": "", + "sourceIp": "192.168.100.1", + "cognitoAuthenticationType": "", + "cognitoAuthenticationProvider": "", + "userArn": "", + "userAgent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/52.0.2743.82 Safari/537.36 OPR/39.0.2256.48", # noqa: E501 + "user": "", + }, + "resourcePath": "/{proxy+}", + "httpMethod": method, + "apiId": "123", + }, + "resource": "/{proxy+}", + "httpMethod": method, + "multiValueQueryStringParameters": { + k: v for k, v in multi_value_query_parameters.items() + } + if multi_value_query_parameters + else None, + "stageVariables": {"stageVarName": "stageVarValue"}, + } + + +def test_aws_api_gateway_scope_basic(): + """ + Test the event from the AWS docs + """ + example_event = { + "resource": "/", + "path": "/", + "httpMethod": "GET", + "requestContext": {"resourcePath": "/", "httpMethod": "GET", "path": "/Prod/"}, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", # noqa: E501 + "accept-encoding": "gzip, deflate, br", + "Host": "70ixmpl4fl.execute-api.us-east-2.amazonaws.com", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.132 Safari/537.36", # noqa: E501 + "X-Amzn-Trace-Id": "Root=1-5e66d96f-7491f09xmpl79d18acf3d050", + }, + "multiValueHeaders": { + "accept": [ + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" # noqa: E501 + ], + "accept-encoding": ["gzip, deflate, br"], + }, + "queryStringParameters": {"foo": "bar"}, + "multiValueQueryStringParameters": None, + "pathParameters": None, + "stageVariables": None, + "body": None, + "isBase64Encoded": False, + } + example_context = {} + handler = AwsApiGateway(example_event, example_context) + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": example_event, + "aws.eventType": "AWS_API_GATEWAY", + "client": (None, 0), + "headers": [ + [ + b"accept", + b"text/html,application/xhtml+xml,application/xml;q=0.9," + b"image/webp,image/apng,*/*;q=0.8," + b"application/signed-exchange;v=b3;q=0.9", + ], + [b"accept-encoding", b"gzip, deflate, br"], + ], + "http_version": "1.1", + "method": "GET", + "path": "/", + "query_string": b"foo=bar", + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("mangum", 80), + "type": "http", + } + + +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body,body_base64_encoded," + "query_string,scope_body", + [ + ("GET", "/hello/world", None, None, False, b"", None), + ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "GET", + "/my/resource", + {"name": ["me", "you"]}, + None, + False, + b"name=me&name=you", + None, + ), + ( + "GET", + "", + {"name": ["me", "you"], "pet": ["dog"]}, + None, + False, + b"name=me&name=you&pet=dog", + None, + ), + # A 1x1 red px gif + ( + "POST", + "/img", + None, + b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + b"", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + ), + ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), + ], +) +def test_aws_api_gateway_scope_real( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = get_mock_aws_api_gateway_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + example_context = {} + handler = AwsApiGateway(event, example_context) + + scope_path = path + if scope_path == "": + scope_path = "/" + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": event, + "aws.eventType": "AWS_API_GATEWAY", + "client": ("192.168.100.1", 0), + "headers": [ + [ + b"accept", + b"text/html,application/xhtml+xml,application/xml;q=0.9,image/" + b"webp,*/*;q=0.8", + ], + [b"accept-encoding", b"gzip, deflate, lzma, sdch, br"], + [b"accept-language", b"en-US,en;q=0.8"], + [b"cloudfront-forwarded-proto", b"https"], + [b"cloudfront-is-desktop-viewer", b"true"], + [b"cloudfront-is-mobile-viewer", b"false"], + [b"cloudfront-is-smarttv-viewer", b"false"], + [b"cloudfront-is-tablet-viewer", b"false"], + [b"cloudfront-viewer-country", b"US"], + [b"cookie", b"cookie1; cookie2"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + [b"upgrade-insecure-requests", b"1"], + [b"x-forwarded-for", b"192.168.100.1, 192.168.1.1"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-proto", b"https"], + ], + "http_version": "1.1", + "method": method, + "path": scope_path, + "query_string": query_string, + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + assert handler.body == scope_body + + +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body,body_base64_encoded," + "query_string,scope_body", + [ + ("GET", "/test/hello", None, None, False, b"", None), + ], +) +def test_aws_api_gateway_base_path( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = get_mock_aws_api_gateway_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote(event["path"]) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + handler = Mangum(app, lifespan="off", base_path=None) + response = handler(event, {}) + + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote( + event["path"][len(f"/{api_gateway_base_path}") :] + ) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + api_gateway_base_path = "test" + handler = Mangum(app, lifespan="off", base_path=api_gateway_base_path) + response = handler(event, {}) + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + +@pytest.mark.parametrize( + "method,content_type,raw_res_body,res_body,res_base64_encoded", + [ + ("GET", b"text/plain; charset=utf-8", b"Hello world", "Hello world", False), + # A 1x1 red px gif + ( + "POST", + b"image/gif", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + "R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + ), + ], +) +def test_aws_api_gateway_response( + method, content_type, raw_res_body, res_body, res_base64_encoded +): + async def app(scope, receive, send): + assert scope["aws.eventType"] == "AWS_API_GATEWAY" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_api_gateway_event(method, "/test", {}, None, False) + + handler = Mangum(app, lifespan="off") + + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": res_base64_encoded, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": res_body, + } diff --git a/tests/handlers/test_aws_cf_lambda_at_edge.py b/tests/handlers/test_aws_cf_lambda_at_edge.py new file mode 100644 index 00000000..df349e93 --- /dev/null +++ b/tests/handlers/test_aws_cf_lambda_at_edge.py @@ -0,0 +1,288 @@ +import urllib.parse + +import pytest + +from mangum import Mangum +from mangum.handlers import AwsCfLambdaAtEdge + + +def mock_lambda_at_edge_event( + method, path, multi_value_query_parameters, body, body_base64_encoded +): + headers_raw = { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-for": "192.168.100.1", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + } + headers = {} + for key, value in headers_raw.items(): + headers[key.lower()] = [{"key": key, "value": value}] + + event = { + "Records": [ + { + "cf": { + "config": { + "distributionDomainName": "mock-distribution.local.localhost", + "distributionId": "ABC123DEF456G", + "eventType": "origin-request", + "requestId": "lBEBo2N0JKYUP2JXwn_4am2xAXB2GzcL2FlwXI8G59PA8wghF2ImFQ==", # noqa: E501 + }, + "request": { + "clientIp": "192.168.100.1", + "headers": headers, + "method": method, + "origin": { + "custom": { + "customHeaders": { + "x-lae-env-custom-var": [ + { + "key": "x-lae-env-custom-var", + "value": "environment variable", + } + ], + }, + "domainName": "www.example.com", + "keepaliveTimeout": 5, + "path": "", + "port": 80, + "protocol": "http", + "readTimeout": 30, + "sslProtocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + } + }, + "querystring": urllib.parse.urlencode( + multi_value_query_parameters + if multi_value_query_parameters + else {}, + doseq=True, + ), + "uri": path, + }, + } + } + ] + } + + if body is not None: + event["Records"][0]["cf"]["request"]["body"] = { + "inputTruncated": False, + "action": "read-only", + "encoding": "base64" if body_base64_encoded else "text", + "data": body, + } + return event + + +def test_aws_cf_lambda_at_edge_scope_basic(): + """ + Test the event from the AWS docs + """ + example_event = { + "Records": [ + { + "cf": { + "config": { + "distributionDomainName": "d111111abcdef8.cloudfront.net", + "distributionId": "EDFDVBD6EXAMPLE", + "eventType": "origin-request", + "requestId": "4TyzHTaYWb1GX1qTfsHhEqV6HUDd_BzoBZnwfnvQc_1oF26ClkoUSEQ==", # noqa: E501 + }, + "request": { + "clientIp": "203.0.113.178", + "headers": { + "x-forwarded-for": [ + {"key": "X-Forwarded-For", "value": "203.0.113.178"} + ], + "user-agent": [ + {"key": "User-Agent", "value": "Amazon CloudFront"} + ], + "via": [ + { + "key": "Via", + "value": "2.0 2afae0d44e2540f472c0635ab62c232b.cloudfront.net (CloudFront)", # noqa: E501 + } + ], + "host": [{"key": "Host", "value": "example.org"}], + "cache-control": [ + { + "key": "Cache-Control", + "value": "no-cache, cf-no-cache", + } + ], + }, + "method": "GET", + "origin": { + "custom": { + "customHeaders": {}, + "domainName": "example.org", + "keepaliveTimeout": 5, + "path": "", + "port": 443, + "protocol": "https", + "readTimeout": 30, + "sslProtocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + } + }, + "querystring": "", + "uri": "/", + }, + } + } + ] + } + example_context = {} + handler = AwsCfLambdaAtEdge(example_event, example_context) + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": example_event, + "aws.eventType": "AWS_CF_LAMBDA_AT_EDGE", + "client": ("203.0.113.178", 0), + "headers": [ + [b"x-forwarded-for", b"203.0.113.178"], + [b"user-agent", b"Amazon CloudFront"], + [ + b"via", + b"2.0 2afae0d44e2540f472c0635ab62c232b.cloudfront.net (CloudFront)", + ], + [b"host", b"example.org"], + [b"cache-control", b"no-cache, cf-no-cache"], + ], + "http_version": "1.1", + "method": "GET", + "path": "/", + "query_string": b"", + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("example.org", 80), + "type": "http", + } + + +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body," + "body_base64_encoded,query_string,scope_body", + [ + ("GET", "/hello/world", None, None, False, b"", None), + ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "GET", + "/my/resource", + {"name": ["me", "you"]}, + None, + False, + b"name=me&name=you", + None, + ), + ( + "GET", + "", + {"name": ["me", "you"], "pet": ["dog"]}, + None, + False, + b"name=me&name=you&pet=dog", + None, + ), + # A 1x1 red px gif + ( + "POST", + "/img", + None, + b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + b"", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + ), + ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), + ], +) +def test_aws_api_gateway_scope_real( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = mock_lambda_at_edge_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + example_context = {} + handler = AwsCfLambdaAtEdge(event, example_context) + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": event, + "aws.eventType": "AWS_CF_LAMBDA_AT_EDGE", + "client": ("192.168.100.1", 0), + "headers": [ + [b"accept-encoding", b"gzip,deflate"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-for", b"192.168.100.1"], + [b"x-forwarded-proto", b"https"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + ], + "http_version": "1.1", + "method": method, + "path": path, + "query_string": query_string, + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + assert handler.body == scope_body + + +@pytest.mark.parametrize( + "method,content_type,raw_res_body,res_body,res_base64_encoded", + [ + ("GET", b"text/plain; charset=utf-8", b"Hello world", "Hello world", False), + # A 1x1 red px gif + ( + "POST", + b"image/gif", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + "R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + ), + ], +) +def test_aws_lambda_at_edge_response( + method, content_type, raw_res_body, res_body, res_base64_encoded +): + async def app(scope, receive, send): + assert scope["aws.eventType"] == "AWS_CF_LAMBDA_AT_EDGE" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = mock_lambda_at_edge_event(method, "/test", {}, None, False) + + handler = Mangum(app, lifespan="off") + + response = handler(event, {}) + assert response == { + "status": 200, + "isBase64Encoded": res_base64_encoded, + "headers": { + "content-type": [{"key": "content-type", "value": content_type.decode()}] + }, + "body": res_body, + } diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_aws_http_gateway.py new file mode 100644 index 00000000..2deac9f5 --- /dev/null +++ b/tests/handlers/test_aws_http_gateway.py @@ -0,0 +1,600 @@ +import urllib.parse + +import pytest + +from mangum import Mangum +from mangum.handlers import AwsHttpGateway + + +def get_mock_aws_http_gateway_event_v1( + method, path, query_parameters, body, body_base64_encoded +): + query_string = urllib.parse.urlencode(query_parameters if query_parameters else {}) + return { + "version": "1.0", + "resource": path, + "path": path, + "httpMethod": method, + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "multiValueHeaders": { + "accept-encoding": ["gzip", "deflate"], + "x-forwarded-port": ["443"], + "x-forwarded-proto": ["https"], + "host": ["test.execute-api.us-west-2.amazonaws.com"], + }, + "queryStringParameters": {k: v[0] for k, v in query_parameters.items()} + if query_parameters + else {}, + "multiValueQueryStringParameters": {k: v for k, v in query_parameters.items()} + if query_parameters + else {}, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": {"claims": None, "scopes": None}, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": method, + "identity": { + "accessKey": None, + "accountId": None, + "caller": None, + "cognitoAuthenticationProvider": None, + "cognitoAuthenticationType": None, + "cognitoIdentityId": None, + "cognitoIdentityPoolId": None, + "principalOrgId": None, + "sourceIp": "192.168.100.1", + "user": None, + "userAgent": "user-agent", + "userArn": None, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT", + }, + }, + }, + "path": path, + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": None, + "resourcePath": path, + "stage": "$default", + }, + "pathParameters": query_string, + "stageVariables": None, + "body": body, + "isBase64Encoded": body_base64_encoded, + } + + +def get_mock_aws_http_gateway_event_v2( + method, path, query_parameters, body, body_base64_encoded +): + query_string = urllib.parse.urlencode(query_parameters if query_parameters else {}) + return { + "version": "2.0", + "routeKey": "$default", + "rawPath": path, + "rawQueryString": query_string, + "cookies": ["cookie1", "cookie2"], + "headers": { + "accept-encoding": "gzip,deflate", + "x-forwarded-port": "443", + "x-forwarded-proto": "https", + "host": "test.execute-api.us-west-2.amazonaws.com", + }, + "queryStringParameters": {k: v[0] for k, v in query_parameters.items()} + if query_parameters + else {}, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": method, + "path": path, + "protocol": "HTTP/1.1", + "sourceIp": "192.168.100.1", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390, + }, + "body": body, + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": body_base64_encoded, + "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, + } + + +def test_aws_http_gateway_scope_basic_v1(): + """ + Test the event from the AWS docs + """ + example_event = { + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": {"Header1": "value1", "Header2": "value2"}, + "multiValueHeaders": {"Header1": ["value1"], "Header2": ["value1", "value2"]}, + "queryStringParameters": {"parameter1": "value1", "parameter2": "value"}, + "multiValueQueryStringParameters": { + "parameter1": ["value1", "value2"], + "parameter2": ["value"], + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": {"claims": None, "scopes": None}, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": None, + "accountId": None, + "caller": None, + "cognitoAuthenticationProvider": None, + "cognitoAuthenticationType": None, + "cognitoIdentityId": None, + "cognitoIdentityPoolId": None, + "principalOrgId": None, + "sourceIp": "IP", + "user": None, + "userAgent": "user-agent", + "userArn": None, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT", + }, + }, + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": None, + "resourcePath": "/my/path", + "stage": "$default", + }, + "pathParameters": None, + "stageVariables": None, + "body": "Hello from Lambda!", + "isBase64Encoded": False, + } + example_context = {} + handler = AwsHttpGateway(example_event, example_context) + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": example_event, + "aws.eventType": "AWS_HTTP_GATEWAY", + "client": ("IP", 0), + "headers": [[b"header1", b"value1"], [b"header2", b"value1, value2"]], + "http_version": "1.1", + "method": "GET", + "path": "/my/path", + "query_string": b"parameter1=value1¶meter1=value2¶meter2=value", + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("mangum", 80), + "type": "http", + } + + +def test_aws_http_gateway_scope_v1_only_non_multi_headers(): + """ + Ensure only queryStringParameters headers still works (unsure if this is possible + from HTTP Gateway) + """ + example_event = get_mock_aws_http_gateway_event_v1( + "GET", "/test", {"hello": ["world", "life"]}, None, False + ) + del example_event["multiValueQueryStringParameters"] + example_context = {} + handler = AwsHttpGateway(example_event, example_context) + assert handler.scope.as_dict()["query_string"] == b"hello=world" + + +def test_aws_http_gateway_scope_v1_no_headers(): + """ + Ensure no headers still works (unsure if this is possible from HTTP Gateway) + """ + example_event = get_mock_aws_http_gateway_event_v1( + "GET", "/test", {"hello": ["world", "life"]}, None, False + ) + del example_event["multiValueQueryStringParameters"] + del example_event["queryStringParameters"] + example_context = {} + handler = AwsHttpGateway(example_event, example_context) + assert handler.scope.as_dict()["query_string"] == b"" + + +def test_aws_http_gateway_scope_basic_v2(): + """ + Test the event from the AWS docs + """ + example_event = { + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": ["cookie1", "cookie2"], + "headers": {"Header1": "value1", "Header2": "value1,value2"}, + "queryStringParameters": {"parameter1": "value1,value2", "parameter2": "value"}, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT", + }, + } + }, + "authorizer": { + "jwt": { + "claims": {"claim1": "value1", "claim2": "value2"}, + "scopes": ["scope1", "scope2"], + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "IP", + "userAgent": "agent", + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390, + }, + "body": "Hello from Lambda", + "pathParameters": {"parameter1": "value1"}, + "isBase64Encoded": False, + "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, + } + example_context = {} + handler = AwsHttpGateway(example_event, example_context) + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": example_event, + "aws.eventType": "AWS_HTTP_GATEWAY", + "client": ("IP", 0), + "headers": [ + [b"header1", b"value1"], + [b"header2", b"value1,value2"], + [b"cookie", b"cookie1; cookie2"], + ], + "http_version": "1.1", + "method": "POST", + "path": "/my/path", + "query_string": b"parameter1=value1¶meter1=value2¶meter2=value", + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("mangum", 80), + "type": "http", + } + + +def test_aws_http_gateway_scope_bad_version(): + """ + Set a version we don't support + + Version is the only thing that is different here, we should be checking that + specifically + """ + example_event = get_mock_aws_http_gateway_event_v2("GET", "/test", {}, None, False) + example_event["version"] = "9001.1" + example_context = {} + handler = AwsHttpGateway(example_event, example_context) + with pytest.raises(RuntimeError): + handler.scope.as_dict() + + +@pytest.mark.parametrize( + "method,path,query_parameters,req_body,body_base64_encoded,query_string,scope_body", + [ + ("GET", "/my/test/path", None, None, False, b"", None), + ("GET", "", {"name": "me"}, None, False, b"name=me", None), + # A 1x1 red px gif + ( + "POST", + "/img", + None, + b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + b"", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + ), + ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), + ], +) +def test_aws_http_gateway_scope_real_v1( + method, + path, + query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +) -> None: + event = get_mock_aws_http_gateway_event_v1( + method, path, query_parameters, req_body, body_base64_encoded + ) + example_context = {} + handler = AwsHttpGateway(event, example_context) + + scope_path = path + if scope_path == "": + scope_path = "/" + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": event, + "aws.eventType": "AWS_HTTP_GATEWAY", + "client": ("192.168.100.1", 0), + "headers": [ + [b"accept-encoding", b"gzip, deflate"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-proto", b"https"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + ], + "http_version": "1.1", + "method": method, + "path": scope_path, + "query_string": query_string, + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + assert handler.body == scope_body + + +@pytest.mark.parametrize( + "method,path,query_parameters,req_body,body_base64_encoded,query_string,scope_body", + [ + ("GET", "/my/test/path", None, None, False, b"", None), + ("GET", "", {"name": "me"}, None, False, b"name=me", None), + # A 1x1 red px gif + ( + "POST", + "/img", + None, + b"R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + b"", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + ), + ("POST", "/form-submit", None, b"say=Hi&to=Mom", False, b"", b"say=Hi&to=Mom"), + ], +) +def test_aws_http_gateway_scope_real_v2( + method, + path, + query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +) -> None: + event = get_mock_aws_http_gateway_event_v2( + method, path, query_parameters, req_body, body_base64_encoded + ) + example_context = {} + handler = AwsHttpGateway(event, example_context) + + scope_path = path + if scope_path == "": + scope_path = "/" + + assert handler.scope.as_dict() == { + "asgi": {"version": "3.0"}, + "aws.context": {}, + "aws.event": event, + "aws.eventType": "AWS_HTTP_GATEWAY", + "client": ("192.168.100.1", 0), + "headers": [ + [b"accept-encoding", b"gzip,deflate"], + [b"x-forwarded-port", b"443"], + [b"x-forwarded-proto", b"https"], + [b"host", b"test.execute-api.us-west-2.amazonaws.com"], + [b"cookie", b"cookie1; cookie2"], + ], + "http_version": "1.1", + "method": method, + "path": scope_path, + "query_string": query_string, + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("test.execute-api.us-west-2.amazonaws.com", 443), + "type": "http", + } + + assert handler.body == scope_body + + +@pytest.mark.parametrize( + "method,content_type,raw_res_body,res_body,res_base64_encoded", + [ + ("GET", b"text/plain; charset=utf-8", b"Hello world", "Hello world", False), + ( + "GET", + b"application/json", + b'{"hello": "world", "foo": true}', + '{"hello": "world", "foo": true}', + False, + ), + ("GET", None, b"Hello world", "SGVsbG8gd29ybGQ=", True), + ( + "GET", + None, + b'{"hello": "world", "foo": true}', + "eyJoZWxsbyI6ICJ3b3JsZCIsICJmb28iOiB0cnVlfQ==", + True, + ), + # A 1x1 red px gif + ( + "POST", + b"image/gif", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + "R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + ), + ], +) +def test_aws_http_gateway_response_v1( + method, content_type, raw_res_body, res_body, res_base64_encoded +): + """ + Test response types make sense. v1 does less magic than v2. + """ + + async def app(scope, receive, send): + assert scope["aws.eventType"] == "AWS_HTTP_GATEWAY" + headers = [] + if content_type is not None: + headers.append([b"content-type", content_type]) + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_http_gateway_event_v1(method, "/test", {}, None, False) + + handler = Mangum(app, lifespan="off") + + response = handler(event, {}) + + res_headers = {} + if content_type is not None: + res_headers = {"content-type": content_type.decode()} + + assert response == { + "statusCode": 200, + "isBase64Encoded": res_base64_encoded, + "headers": res_headers, + "multiValueHeaders": {}, + "body": res_body, + } + + +@pytest.mark.parametrize( + "method,content_type,raw_res_body,res_body,res_base64_encoded", + [ + ("GET", b"text/plain; charset=utf-8", b"Hello world", "Hello world", False), + ( + "GET", + b"application/json", + b'{"hello": "world", "foo": true}', + '{"hello": "world", "foo": true}', + False, + ), + ("GET", None, b"Hello world", "Hello world", False), + ( + "GET", + None, + b'{"hello": "world", "foo": true}', + '{"hello": "world", "foo": true}', + False, + ), + # A 1x1 red px gif + ( + "POST", + b"image/gif", + b"GIF87a\x01\x00\x01\x00\x80\x01\x00\xff\x00\x00\x00\x00\x00," + b"\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;", + "R0lGODdhAQABAIABAP8AAAAAACwAAAAAAQABAAACAkQBADs=", + True, + ), + ], +) +def test_aws_http_gateway_response_v2( + method, content_type, raw_res_body, res_body, res_base64_encoded +): + async def app(scope, receive, send): + assert scope["aws.eventType"] == "AWS_HTTP_GATEWAY" + headers = [] + if content_type is not None: + headers.append([b"content-type", content_type]) + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_http_gateway_event_v2(method, "/test", {}, None, False) + + handler = Mangum(app, lifespan="off") + + response = handler(event, {}) + + if content_type is None: + content_type = b"application/json" + assert response == { + "statusCode": 200, + "isBase64Encoded": res_base64_encoded, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": res_body, + } diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 7d9bfec5..c2cffa1a 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,7 +1,6 @@ import pytest from mangum.exceptions import ConfigurationError -from mangum.adapter import DEFAULT_TEXT_MIME_TYPES from mangum import Mangum @@ -12,9 +11,6 @@ async def app(scope, receive, send): def test_default_settings(): handler = Mangum(app) assert handler.lifespan == "auto" - assert handler.log_level == "info" - assert handler.text_mime_types == DEFAULT_TEXT_MIME_TYPES - assert handler.api_gateway_base_path is None @pytest.mark.parametrize( @@ -24,11 +20,6 @@ def test_default_settings(): {"lifespan": "unknown"}, "Invalid argument supplied for `lifespan`. Choices are: auto|on|off", ), - ( - {"log_level": "unknown"}, - "Invalid argument supplied for `log_level`. Choices are: " - "critical|error|warning|info|debug", - ), ], ) def test_invalid_options(arguments, message): diff --git a/tests/test_http.py b/tests/test_http.py index ad0b4eed..c82efd8c 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,7 +1,6 @@ import base64 import gzip import json -import urllib.parse import pytest import brotli @@ -13,142 +12,15 @@ @pytest.mark.parametrize( - "mock_http_event,query_string", - [ - (["GET", None, None], b""), - (["GET", None, {"name": ["me"]}], b"name=me"), - (["GET", None, {"name": ["me", "you"]}], b"name=me&name=you"), - ( - ["GET", None, {"name": ["me", "you"], "pet": ["dog"]}], - b"name=me&name=you&pet=dog", - ), - ], - indirect=["mock_http_event"], -) -def test_http_request(mock_http_event, query_string) -> None: - async def app(scope, receive, send): - assert scope == { - "asgi": {"version": "3.0"}, - "aws.context": {}, - "aws.event": { - "body": None, - "headers": { - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", - "Accept-Encoding": "gzip, deflate, lzma, sdch, " "br", - "Accept-Language": "en-US,en;q=0.8", - "CloudFront-Forwarded-Proto": "https", - "CloudFront-Is-Desktop-Viewer": "true", - "CloudFront-Is-Mobile-Viewer": "false", - "CloudFront-Is-SmartTV-Viewer": "false", - "CloudFront-Is-Tablet-Viewer": "false", - "CloudFront-Viewer-Country": "US", - "Cookie": "cookie1; cookie2", - "Host": "test.execute-api.us-west-2.amazonaws.com", - "Upgrade-Insecure-Requests": "1", - "X-Forwarded-For": "192.168.100.1, 192.168.1.1", - "X-Forwarded-Port": "443", - "X-Forwarded-Proto": "https", - }, - "httpMethod": "GET", - "path": "/test/hello", - "pathParameters": {"proxy": "hello"}, - "queryStringParameters": mock_http_event["queryStringParameters"], - "multiValueQueryStringParameters": mock_http_event[ - "multiValueQueryStringParameters" - ], - "requestContext": { - "accountId": "123456789012", - "apiId": "123", - "httpMethod": "GET", - "identity": { - "accountId": "", - "apiKey": "", - "caller": "", - "cognitoAuthenticationProvider": "", - "cognitoAuthenticationType": "", - "cognitoIdentityId": "", - "cognitoIdentityPoolId": "", - "sourceIp": "192.168.100.1", - "user": "", - "userAgent": "Mozilla/5.0 " - "(Macintosh; " - "Intel Mac OS " - "X 10_11_6) " - "AppleWebKit/537.36 " - "(KHTML, like " - "Gecko) " - "Chrome/52.0.2743.82 " - "Safari/537.36 " - "OPR/39.0.2256.48", - "userArn": "", - }, - "requestId": "41b45ea3-70b5-11e6-b7bd-69b5aaebc7d9", - "resourceId": "us4z18", - "resourcePath": "/{proxy+}", - "stage": "Prod", - }, - "resource": "/{proxy+}", - "stageVariables": {"stageVarName": "stageVarValue"}, - }, - "client": ("192.168.100.1", 0), - "headers": [ - [ - b"accept", - b"text/html,application/xhtml+xml,application/xml;q=0.9,image/" - b"webp,*/*;q=0.8", - ], - [b"accept-encoding", b"gzip, deflate, lzma, sdch, br"], - [b"accept-language", b"en-US,en;q=0.8"], - [b"cloudfront-forwarded-proto", b"https"], - [b"cloudfront-is-desktop-viewer", b"true"], - [b"cloudfront-is-mobile-viewer", b"false"], - [b"cloudfront-is-smarttv-viewer", b"false"], - [b"cloudfront-is-tablet-viewer", b"false"], - [b"cloudfront-viewer-country", b"US"], - [b"cookie", b"cookie1; cookie2"], - [b"host", b"test.execute-api.us-west-2.amazonaws.com"], - [b"upgrade-insecure-requests", b"1"], - [b"x-forwarded-for", b"192.168.100.1, 192.168.1.1"], - [b"x-forwarded-port", b"443"], - [b"x-forwarded-proto", b"https"], - ], - "http_version": "1.1", - "method": "GET", - "path": "/test/hello", - "query_string": query_string, - "raw_path": None, - "root_path": "", - "scheme": "https", - "server": ("test.execute-api.us-west-2.amazonaws.com", 443), - "type": "http", - } - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain; charset=utf-8"]], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - handler = Mangum(app, lifespan="off") - - response = handler(mock_http_event, {}) - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {"content-type": "text/plain; charset=utf-8"}, - "body": "Hello, world!", - } - - -@pytest.mark.parametrize( - "mock_http_event", [["GET", None, {"name": ["me", "you"]}]], indirect=True + "mock_aws_api_gateway_event", + [["GET", None, {"name": ["me", "you"]}]], + indirect=True, ) -def test_http_response(mock_http_event) -> None: +def test_http_response(mock_aws_api_gateway_event) -> None: async def app(scope, receive, send): assert scope == { "asgi": {"version": "3.0"}, + "aws.eventType": "AWS_API_GATEWAY", "aws.context": {}, "aws.event": { "body": None, @@ -254,7 +126,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "statusCode": 200, "isBase64Encoded": False, @@ -267,289 +139,30 @@ async def app(scope, receive, send): @pytest.mark.parametrize( - "mock_http_elb_singlevalue_event", - [["GET", None, {"name": ["me", "you"]}]], - indirect=True, + "mock_aws_api_gateway_event", [["GET", None, None]], indirect=True ) -def test_elb_singlevalue_http_response(mock_http_elb_singlevalue_event) -> None: - async def app(scope, receive, send): - assert scope == { - "asgi": {"version": "3.0"}, - "aws.context": {}, - "aws.event": { - "body": None, - "isBase64Encoded": False, - "headers": { - "accept-encoding": "gzip, deflate", - "cookie": "cookie1; cookie2", - "host": "test.execute-api.us-west-2.amazonaws.com", - "x-forwarded-for": "192.168.100.3, 192.168.100.2, 192.168.100.1", - "x-forwarded-port": "443", - "x-forwarded-proto": "https", - }, - "httpMethod": "GET", - "path": "/my/path", - "queryStringParameters": {"name": "you"}, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0" - } - }, - }, - "client": ("192.168.100.1", 0), - "headers": [ - [b"accept-encoding", b"gzip, deflate"], - [b"cookie", b"cookie1; cookie2"], - [b"host", b"test.execute-api.us-west-2.amazonaws.com"], - [b"x-forwarded-for", b"192.168.100.3, 192.168.100.2, 192.168.100.1"], - [b"x-forwarded-port", b"443"], - [b"x-forwarded-proto", b"https"], - ], - "http_version": "1.1", - "method": "GET", - "path": "/my/path", - "query_string": b"name=you", - "raw_path": None, - "root_path": "", - "scheme": "https", - "server": ("test.execute-api.us-west-2.amazonaws.com", 443), - "type": "http", - } - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [ - [b"content-type", b"text/plain; charset=utf-8"], - [b"set-cookie", b"cookie1=cookie1; Secure"], - [b"set-cookie", b"cookie2=cookie2; Secure"], - [b"set-cookie", b"cookie3=cookie3; Secure"], - ], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - handler = Mangum(app, lifespan="off") - response = handler(mock_http_elb_singlevalue_event, {}) - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": { - "content-type": "text/plain; charset=utf-8", - "set-cookie": "cookie1=cookie1; Secure", - "Set-cookie": "cookie2=cookie2; Secure", - "sEt-cookie": "cookie3=cookie3; Secure", - }, - "body": "Hello, world!", - } - - -@pytest.mark.parametrize( - "mock_http_elb_multivalue_event", - [["GET", None, {"name": ["me", "you"]}]], - indirect=True, -) -def test_elb_multivalue_http_response(mock_http_elb_multivalue_event) -> None: - async def app(scope, receive, send): - assert scope == { - "asgi": {"version": "3.0"}, - "aws.context": {}, - "aws.event": { - "body": None, - "isBase64Encoded": False, - "multiValueHeaders": { - "accept-encoding": ["gzip, deflate"], - "cookie": ["cookie1; cookie2"], - "host": ["test.execute-api.us-west-2.amazonaws.com"], - "x-forwarded-for": ["192.168.100.3, 192.168.100.2, 192.168.100.1"], - "x-forwarded-port": ["443"], - "x-forwarded-proto": ["https"], - }, - "httpMethod": "GET", - "path": "/my/path", - "multiValueQueryStringParameters": {"name": ["me", "you"]}, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0" - } - }, - }, - "client": ("192.168.100.1", 0), - "headers": [ - [b"accept-encoding", b"gzip, deflate"], - [b"cookie", b"cookie1; cookie2"], - [b"host", b"test.execute-api.us-west-2.amazonaws.com"], - [b"x-forwarded-for", b"192.168.100.3, 192.168.100.2, 192.168.100.1"], - [b"x-forwarded-port", b"443"], - [b"x-forwarded-proto", b"https"], - ], - "http_version": "1.1", - "method": "GET", - "path": "/my/path", - "query_string": b"name=me&name=you", - "raw_path": None, - "root_path": "", - "scheme": "https", - "server": ("test.execute-api.us-west-2.amazonaws.com", 443), - "type": "http", - } - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [ - [b"content-type", b"text/plain; charset=utf-8"], - [b"set-cookie", b"cookie1=cookie1; Secure"], - [b"set-cookie", b"cookie2=cookie2; Secure"], - ], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - handler = Mangum(app, lifespan="off") - response = handler(mock_http_elb_multivalue_event, {}) - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {}, - "multiValueHeaders": { - "content-type": ["text/plain; charset=utf-8"], - "set-cookie": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], - }, - "body": "Hello, world!", - } - - -@pytest.mark.parametrize("mock_http_event", [["GET", "123", None]], indirect=True) -def test_http_response_with_body(mock_http_event) -> None: - async def app(scope, receive, send): - assert scope["type"] == "http" - - body = [b"4", b"5", b"6"] - - while True: - message = await receive() - if "body" in message: - body.append(message["body"]) - - if not message.get("more_body", False): - body = b"".join(body) - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain; charset=utf-8"]], - } - ) - await send({"type": "http.response.body", "body": body}) - return - - handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) - - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {"content-type": "text/plain; charset=utf-8"}, - "body": "456123", - } - - -@pytest.mark.parametrize( - "mock_http_event", [["GET", base64.b64encode(b"123"), None]], indirect=True -) -def test_http_binary_request_with_body(mock_http_event) -> None: - async def app(scope, receive, send): - assert scope["type"] == "http" - - body = [] - message = await receive() - - if "body" in message: - body.append(message["body"]) - - if not message.get("more_body", False): - - body = b"".join(body) - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain; charset=utf-8"]], - } - ) - await send({"type": "http.response.body", "body": body}) - - mock_http_event["isBase64Encoded"] = True - handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) - - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {"content-type": "text/plain; charset=utf-8"}, - "body": "123", - } - - -@pytest.mark.parametrize( - "mock_http_event", [["GET", base64.b64encode(b"123"), None]], indirect=True -) -def test_http_binary_request_and_response(mock_http_event) -> None: - async def app(scope, receive, send): - assert scope["type"] == "http" - - body = [] - message = await receive() - - if "body" in message: - body.append(message["body"]) - - if not message.get("more_body", False): - - body = b"".join(body) - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"application/octet-stream"]], - } - ) - await send({"type": "http.response.body", "body": b"abc"}) - - mock_http_event["isBase64Encoded"] = True - handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) - - assert response == { - "statusCode": 200, - "isBase64Encoded": True, - "headers": {"content-type": "application/octet-stream"}, - "body": base64.b64encode(b"abc").decode(), - } - - -@pytest.mark.parametrize("mock_http_event", [["GET", None, None]], indirect=True) -def test_http_exception(mock_http_event) -> None: +def test_http_exception_mid_response(mock_aws_api_gateway_event) -> None: async def app(scope, receive, send): await send({"type": "http.response.start", "status": 200}) raise Exception() - await send({"type": "http.response.body", "body": b"1", "more_body": True}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "body": "Internal Server Error", "headers": {"content-type": "text/plain; charset=utf-8"}, "isBase64Encoded": False, + "multiValueHeaders": {}, "statusCode": 500, } -@pytest.mark.parametrize("mock_http_event", [["GET", None, None]], indirect=True) -def test_http_exception_handler(mock_http_event) -> None: - path = mock_http_event["path"] +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", None, None]], indirect=True +) +def test_http_exception_handler(mock_aws_api_gateway_event) -> None: + path = mock_aws_api_gateway_event["path"] app = Starlette() @app.exception_handler(Exception) @@ -562,27 +175,31 @@ def homepage(request): return PlainTextResponse("Hello, world!") handler = Mangum(app) - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "body": "Error!", "headers": {"content-length": "6", "content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "isBase64Encoded": False, "statusCode": 500, } -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_cycle_state(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", "", None]], indirect=True +) +def test_http_cycle_state(mock_aws_api_gateway_event) -> None: async def app(scope, receive, send): assert scope["type"] == "http" await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "body": "Internal Server Error", "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "isBase64Encoded": False, "statusCode": 500, } @@ -594,80 +211,20 @@ async def app(scope, receive, send): handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "body": "Internal Server Error", "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "isBase64Encoded": False, "statusCode": 500, } -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_api_gateway_base_path(mock_http_event) -> None: - async def app(scope, receive, send): - assert scope["type"] == "http" - assert scope["path"] == urllib.parse.unquote(mock_http_event["path"]) - await send({"type": "http.response.start", "status": 200}) - await send({"type": "http.response.body", "body": b"Hello world!"}) - - handler = Mangum(app, lifespan="off", api_gateway_base_path=None) - response = handler(mock_http_event, {}) - - assert response == { - "body": "Hello world!", - "headers": {}, - "isBase64Encoded": False, - "statusCode": 200, - } - - async def app(scope, receive, send): - assert scope["type"] == "http" - assert scope["path"] == urllib.parse.unquote( - mock_http_event["path"][len(f"/{api_gateway_base_path}") :] - ) - await send({"type": "http.response.start", "status": 200}) - await send({"type": "http.response.body", "body": b"Hello world!"}) - - api_gateway_base_path = "test" - handler = Mangum(app, lifespan="off", api_gateway_base_path=api_gateway_base_path) - response = handler(mock_http_event, {}) - assert response == { - "body": "Hello world!", - "headers": {}, - "isBase64Encoded": False, - "statusCode": 200, - } - - -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_text_mime_types(mock_http_event) -> None: - async def app(scope, receive, send): - assert scope["type"] == "http" - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain; charset=utf-8"]], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - handler = Mangum( - app, lifespan="off", text_mime_types=["application/vnd.apple.pkpass"] - ) - response = handler(mock_http_event, {}) - - assert response == { - "statusCode": 200, - "isBase64Encoded": False, - "headers": {"content-type": "text/plain; charset=utf-8"}, - "body": "Hello, world!", - } - - -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_binary_gzip_response(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", b"", None]], indirect=True +) +def test_http_binary_gzip_response(mock_aws_api_gateway_event) -> None: body = json.dumps({"abc": "defg"}) async def app(scope, receive, send): @@ -683,7 +240,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": body.encode()}) handler = Mangum(GZipMiddleware(app, minimum_size=1), lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response["isBase64Encoded"] assert response["headers"] == { @@ -712,10 +269,11 @@ async def app(scope, receive, send): ], indirect=["mock_http_api_event"], ) -def test_api_request(mock_http_api_event) -> None: +def test_set_cookies(mock_http_api_event) -> None: async def app(scope, receive, send): assert scope == { "asgi": {"version": "3.0"}, + "aws.eventType": "AWS_HTTP_GATEWAY", "aws.context": {}, "aws.event": { "version": "2.0", @@ -800,13 +358,17 @@ async def app(scope, receive, send): "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, - "cookies": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"], + "multiValueHeaders": { + "set-cookie": ["cookie1=cookie1; Secure", "cookie2=cookie2; Secure"] + }, "body": "Hello, world!", } -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_empty_header(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", "", None]], indirect=True +) +def test_http_empty_header(mock_aws_api_gateway_event) -> None: async def app(scope, receive, send): assert scope["type"] == "http" await send( @@ -820,19 +382,20 @@ async def app(scope, receive, send): handler = Mangum(app, lifespan="off") - mock_http_event["headers"] = None + mock_aws_api_gateway_event["headers"] = None - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "body": "Hello, world!", } @pytest.mark.parametrize( - "mock_http_event,response_headers,expected_headers,expected_multi_value_headers", + "mock_aws_api_gateway_event,response_headers,expected_headers,expected_multi_value_headers", [ [ ["GET", None, None], @@ -854,11 +417,14 @@ async def app(scope, receive, send): ], [["GET", None, None], [], {}, {}], ], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) def test_http_response_headers( - mock_http_event, response_headers, expected_headers, expected_multi_value_headers -) -> None: + mock_aws_api_gateway_event, + response_headers, + expected_headers, + expected_multi_value_headers, +): async def app(scope, receive, send): await send( { @@ -871,11 +437,12 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) expected = { "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "body": "Hello, world!", } if expected_headers: @@ -885,8 +452,10 @@ async def app(scope, receive, send): assert response == expected -@pytest.mark.parametrize("mock_http_event", [["GET", "", None]], indirect=True) -def test_http_binary_br_response(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", "", None]], indirect=True +) +def test_http_binary_br_response(mock_aws_api_gateway_event) -> None: body = json.dumps({"abc": "defg"}) async def app(scope, receive, send): @@ -902,7 +471,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": body.encode()}) handler = Mangum(BrotliMiddleware(app, minimum_size=1), lifespan="off") - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response["isBase64Encoded"] assert response["headers"] == { diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index f6d23e0e..af02b651 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -19,15 +19,15 @@ @pytest.mark.parametrize( - "mock_http_event,lifespan", + "mock_aws_api_gateway_event,lifespan", [ (["GET", None, None], "auto"), (["GET", None, None], "on"), (["GET", None, None], "off"), ], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) -def test_lifespan(mock_http_event, lifespan) -> None: +def test_lifespan(mock_aws_api_gateway_event, lifespan) -> None: """ Test each lifespan option using an application that supports lifespan messages. @@ -76,7 +76,7 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan=lifespan) - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) expected = lifespan in ("on", "auto") assert startup_complete == expected @@ -85,20 +85,21 @@ async def app(scope, receive, send): "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "body": "Hello, world!", } @pytest.mark.parametrize( - "mock_http_event,lifespan", + "mock_aws_api_gateway_event,lifespan", [ (["GET", None, None], "auto"), (["GET", None, None], "on"), (["GET", None, None], "off"), ], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) -def test_lifespan_unsupported(mock_http_event, lifespan) -> None: +def test_lifespan_unsupported(mock_aws_api_gateway_event, lifespan) -> None: """ Test each lifespan option with an application that does not support lifespan events. """ @@ -114,22 +115,23 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan=lifespan) - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert response == { "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "body": "Hello, world!", } @pytest.mark.parametrize( - "mock_http_event,lifespan", + "mock_aws_api_gateway_event,lifespan", [(["GET", None, None], "auto"), (["GET", None, None], "on")], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) -def test_lifespan_error(mock_http_event, lifespan, caplog) -> None: +def test_lifespan_error(mock_aws_api_gateway_event, lifespan, caplog) -> None: caplog.set_level(logging.ERROR) async def app(scope, receive, send): @@ -149,23 +151,24 @@ async def app(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) handler = Mangum(app, lifespan=lifespan) - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert "Exception in 'lifespan' protocol." in caplog.text assert response == { "statusCode": 200, "isBase64Encoded": False, "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, "body": "Hello, world!", } @pytest.mark.parametrize( - "mock_http_event,lifespan", + "mock_aws_api_gateway_event,lifespan", [(["GET", None, None], "auto"), (["GET", None, None], "on")], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) -def test_lifespan_unexpected_message(mock_http_event, lifespan) -> None: +def test_lifespan_unexpected_message(mock_aws_api_gateway_event, lifespan) -> None: async def app(scope, receive, send): if scope["type"] == "lifespan": while True: @@ -183,20 +186,20 @@ async def app(scope, receive, send): handler = Mangum(app, lifespan=lifespan) with pytest.raises(LifespanFailure): - handler(mock_http_event, {}) + handler(mock_aws_api_gateway_event, {}) @pytest.mark.parametrize( - "mock_http_event,lifespan,failure_type", + "mock_aws_api_gateway_event,lifespan,failure_type", [ (["GET", None, None], "auto", "startup"), (["GET", None, None], "on", "startup"), (["GET", None, None], "auto", "shutdown"), (["GET", None, None], "on", "shutdown"), ], - indirect=["mock_http_event"], + indirect=["mock_aws_api_gateway_event"], ) -def test_lifespan_failure(mock_http_event, lifespan, failure_type) -> None: +def test_lifespan_failure(mock_aws_api_gateway_event, lifespan, failure_type) -> None: async def app(scope, receive, send): if scope["type"] == "lifespan": while True: @@ -218,15 +221,17 @@ async def app(scope, receive, send): handler = Mangum(app, lifespan=lifespan) with pytest.raises(LifespanFailure): - handler(mock_http_event, {}) + handler(mock_aws_api_gateway_event, {}) -@pytest.mark.parametrize("mock_http_event", [["GET", None, None]], indirect=True) -def test_starlette_lifespan(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", None, None]], indirect=True +) +def test_starlette_lifespan(mock_aws_api_gateway_event) -> None: startup_complete = False shutdown_complete = False - path = mock_http_event["path"] + path = mock_aws_api_gateway_event["path"] app = Starlette() @app.on_event("startup") @@ -247,9 +252,9 @@ def homepage(request): assert not shutdown_complete handler = Mangum(app) - mock_http_event["body"] = None + mock_aws_api_gateway_event["body"] = None - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert startup_complete assert shutdown_complete assert response == { @@ -259,6 +264,7 @@ def homepage(request): "content-length": "13", "content-type": "text/plain; charset=utf-8", }, + "multiValueHeaders": {}, "body": "Hello, world!", } @@ -267,11 +273,13 @@ def homepage(request): IS_PY38, reason="One (or more) of Quart's dependencies does not support Python 3.8." ) @pytest.mark.skipif(IS_PY36, reason="Quart does not support Python 3.6.") -@pytest.mark.parametrize("mock_http_event", [["GET", None, None]], indirect=True) -def test_quart_lifespan(mock_http_event) -> None: +@pytest.mark.parametrize( + "mock_aws_api_gateway_event", [["GET", None, None]], indirect=True +) +def test_quart_lifespan(mock_aws_api_gateway_event) -> None: startup_complete = False shutdown_complete = False - path = mock_http_event["path"] + path = mock_aws_api_gateway_event["path"] app = Quart(__name__) @app.before_serving @@ -292,7 +300,7 @@ async def hello(): assert not shutdown_complete handler = Mangum(app) - response = handler(mock_http_event, {}) + response = handler(mock_aws_api_gateway_event, {}) assert startup_complete assert shutdown_complete @@ -300,5 +308,6 @@ async def hello(): "statusCode": 200, "isBase64Encoded": False, "headers": {"content-length": "12", "content-type": "text/html; charset=utf-8"}, + "multiValueHeaders": {}, "body": "hello world!", }