Merged
Changes from 3 commits
2 changes: 2 additions & 0 deletions .flake8
@@ -0,0 +1,2 @@
[flake8]
max-line-length = 120
Collaborator

The flake8 configuration is taken care of in setup.cfg, so I don't think we need this. In any case, the max-line-length should be 88.

Contributor Author

Ah, there it is! I was wondering where you got 88 from. It seems the old convention was 80 and the modern convention is 120 to 150. Why 88?

Contributor Author

I left it at 120 😈 I am curious though, I never know if there is an official source for stuff like this. Conventions are good, and so is browsing multiple files at once: I can comfortably fit two 120-column files side by side on my 2K monitor ¯_(ツ)_/¯

88 is the default for python black: https://pypi.org/project/black/

Collaborator

It is the default for black, as dltacube mentioned, and it is also the max line length for this project. I personally find anything beyond this uncomfortable to read and review, and I think when lines start getting too long it generally indicates that something should be refactored (though this is the less important reason).

Please change it back to 88.

Contributor Author

Really, it's only the longer URLs that exceed 88 characters. It's brutal to cut everything down so small; lots of weird multiline strings result. I changed it, but it's uglier.

Collaborator

Thanks. I know it may seem nitpicky or pedantic, but it really does make it much easier for me to review contributions with this line length.
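
As a side note on the long-URL concern in this thread, one common way to stay within 88 columns is Python's implicit concatenation of adjacent string literals inside parentheses. The sketch below is only an illustration: the constant name `REST_API_DOCS_URL` is hypothetical, and the URL is the one used in `docs/http.md` in this PR.

```python
# Adjacent string literals inside parentheses are joined at compile time,
# so a long URL can be split across lines without "+" or backslashes.
# REST_API_DOCS_URL is a hypothetical name used only for this example.
REST_API_DOCS_URL = (
    "https://docs.aws.amazon.com/apigateway/latest/developerguide/"
    "apigateway-rest-api.html"
)

print(REST_API_DOCS_URL)
```

The resulting value is a single string with no embedded newline or extra whitespace, so the behavior matches the one-line form.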

13 changes: 11 additions & 2 deletions docs/http.md
@@ -1,7 +1,16 @@
# HTTP

Mangum provides support for both [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) and the newer [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) APIs in API Gateway. It also includes configurable binary response support.

Mangum provides support for the following AWS HTTP Lambda Event Sources:

* [API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html)
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-apigateway.html))
* [HTTP Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html)
([Event Examples](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html))
* [Application Load Balancer (ALB)](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/introduction.html)
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html))
* [CloudFront Lambda@Edge](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/lambda-at-the-edge.html)
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html))

```python
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
Expand Down
4 changes: 4 additions & 0 deletions mangum/__init__.py
@@ -1 +1,5 @@
from .response import Response
Collaborator

Please move Response and Scope to a single file. The reason isn't the imports; rather, it is easier to maintain a single file with related classes like this.

Contributor Author

@four43 four43 Mar 21, 2021

Since I was able to clean up adapter so much, they might fit nicely in adapter... but that now causes circular references.

from .scope import Scope
from .adapter import Mangum # noqa: F401

__all__ = ["Mangum", "Response", "Scope"]
175 changes: 26 additions & 149 deletions mangum/adapter.py
@@ -1,36 +1,26 @@
import base64
import typing
import logging
import urllib.parse

from dataclasses import dataclass, InitVar
from contextlib import ExitStack
from typing import Any, ContextManager, Dict, TYPE_CHECKING

from mangum.types import ASGIApp, Scope
from mangum.protocols.lifespan import LifespanCycle
from mangum.protocols.http import HTTPCycle
from mangum.exceptions import ConfigurationError
from .exceptions import ConfigurationError
from .handlers import AbstractHandler
from .protocols import HTTPCycle, LifespanCycle
from .types import ASGIApp

if typing.TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
from awslambdaric.lambda_context import LambdaContext

DEFAULT_TEXT_MIME_TYPES = [
"text/",
"application/json",
"application/javascript",
"application/xml",
"application/vnd.api+json",
]

LOG_LEVELS = {
"critical": logging.CRITICAL,
"error": logging.ERROR,
"warning": logging.WARNING,
"info": logging.INFO,
"debug": logging.DEBUG,
}
logger = logging.getLogger("mangum")


@dataclass
class Mangum:
"""
Creates an adapter instance.
@@ -41,153 +31,40 @@ class Mangum:
and `off`. Default is `auto`.
* **log_level** - A string to configure the log level. Choices are: `info`,
`critical`, `error`, `warning`, and `debug`. Default is `info`.
* **api_gateway_base_path** - Base path to strip from URL when using a custom
domain name.
* **text_mime_types** - A list of MIME types to include with the defaults that
should not return a binary response in API Gateway.
"""

app: ASGIApp
lifespan: str = "auto"
log_level: str = "info"
api_gateway_base_path: typing.Optional[str] = None
text_mime_types: InitVar[typing.Optional[typing.List[str]]] = None

def __post_init__(self, text_mime_types: typing.Optional[typing.List[str]]) -> None:
def __init__(
self,
app: ASGIApp,
lifespan: str = "auto",
**handler_kwargs: Dict[str, Any],
):
self.app = app
self.lifespan = lifespan
self.handler_kwargs = handler_kwargs

if self.lifespan not in ("auto", "on", "off"):
raise ConfigurationError(
"Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
)

if self.log_level not in ("critical", "error", "warning", "info", "debug"):
raise ConfigurationError(
"Invalid argument supplied for `log_level`. "
"Choices are: critical|error|warning|info|debug"
)

self.logger = logging.getLogger("mangum")
self.logger.setLevel(LOG_LEVELS[self.log_level])

should_prefix_base_path = (
self.api_gateway_base_path
and not self.api_gateway_base_path.startswith("/")
)
if should_prefix_base_path:
self.api_gateway_base_path = f"/{self.api_gateway_base_path}"

if text_mime_types:
text_mime_types += DEFAULT_TEXT_MIME_TYPES
else:
text_mime_types = DEFAULT_TEXT_MIME_TYPES
self.text_mime_types = text_mime_types

def __call__(self, event: dict, context: "LambdaContext") -> dict:
self.logger.debug("Event received.")
logger.debug("Event received.")

with ExitStack() as stack:
if self.lifespan != "off":
lifespan_cycle: typing.ContextManager = LifespanCycle(
self.app, self.lifespan
)
lifespan_cycle: ContextManager = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)

is_binary = event.get("isBase64Encoded", False)
initial_body = event.get("body") or b""
if is_binary:
initial_body = base64.b64decode(initial_body)
elif not isinstance(initial_body, bytes):
initial_body = initial_body.encode()

scope = self.create_scope(event, context)
http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types)
response = http_cycle(self.app, initial_body)

return response

def create_scope(self, event: dict, context: "LambdaContext") -> Scope:
"""
Creates a scope object according to ASGI specification from a Lambda Event.

https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope

The event comes from various sources: AWS ALB, AWS API Gateway of different
versions and configurations(multivalue header, etc).
Thus, some heuristics is applied to guess an event type.

"""
request_context = event["requestContext"]

if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()
}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True
).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})
).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
client = (source_ip, 0)

if not path: # pragma: no cover
path = "/"
elif self.api_gateway_base_path:
if path.startswith(self.api_gateway_base_path):
path = path[len(self.api_gateway_base_path) :]

scope = {
"type": "http",
"http_version": "1.1",
"method": http_method,
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": urllib.parse.unquote(path),
"raw_path": None,
"root_path": "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": query_string,
"server": server,
"client": client,
"asgi": {"version": "3.0"},
"aws.event": event,
"aws.context": context,
}
handler = AbstractHandler.from_trigger(
event, context, **self.handler_kwargs
)
http_cycle = HTTPCycle(handler.scope)
response = http_cycle(self.app, handler.body)

return scope
return handler.transform_response(response)
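
To make two steps of the removed `create_scope` logic easier to follow, here is a minimal standalone sketch of the header flattening and query-string building for API Gateway v1 / ALB style events. It uses only the standard library. The function names `normalize_headers` and `build_query_string` and the sample event are hypothetical and are not part of Mangum.

```python
import urllib.parse
from typing import Any, Dict


def normalize_headers(event: Dict[str, Any]) -> Dict[str, str]:
    """Lower-case header names; join multi-value headers with ', '."""
    if event.get("multiValueHeaders"):
        return {
            k.lower(): ", ".join(v) if isinstance(v, list) else ""
            for k, v in event["multiValueHeaders"].items()
        }
    return {k.lower(): v for k, v in (event.get("headers") or {}).items()}


def build_query_string(event: Dict[str, Any]) -> bytes:
    """Prefer multi-value parameters, then single-value ones, else empty."""
    if event.get("multiValueQueryStringParameters"):
        return urllib.parse.urlencode(
            event["multiValueQueryStringParameters"], doseq=True
        ).encode()
    if event.get("queryStringParameters"):
        return urllib.parse.urlencode(event["queryStringParameters"]).encode()
    return b""


# Hypothetical, heavily trimmed event used only for illustration.
sample_event = {
    "multiValueHeaders": {
        "Accept": ["text/html", "application/json"],
        "Host": ["example.org"],
    },
    "multiValueQueryStringParameters": {"tag": ["a", "b"]},
}

headers = normalize_headers(sample_event)
query_string = build_query_string(sample_event)
print(headers)
print(query_string)
```

Passing `doseq=True` is what lets `urlencode` expand a list of values into repeated `key=value` pairs instead of encoding the list's string representation.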
13 changes: 13 additions & 0 deletions mangum/handlers/__init__.py
@@ -0,0 +1,13 @@
from .abstract_handler import AbstractHandler
from .aws_alb import AwsAlb
from .aws_api_gateway import AwsApiGateway
from .aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
from .aws_http_gateway import AwsHttpGateway

__all__ = [
    "AbstractHandler",
    "AwsAlb",
    "AwsApiGateway",
    "AwsCfLambdaAtEdge",
    "AwsHttpGateway",
]
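
The handlers exported here are chosen per event by `AbstractHandler.from_trigger`, which inspects the shape of the incoming event and checks the most specific shapes first. The sketch below mirrors those checks in a standalone function that returns a label instead of a handler instance. The name `detect_event_source`, the labels, and the trimmed sample events are assumptions made for illustration only.

```python
from typing import Any, Dict


def detect_event_source(event: Dict[str, Any]) -> str:
    """Return a label for the event source, most specific check first."""
    request_context = event.get("requestContext") or {}
    if "elb" in request_context:
        return "alb"

    records = event.get("Records") or []
    if records and "cf" in records[0]:
        return "lambda_at_edge"

    if "version" in event and "requestContext" in event:
        return "http_gateway"

    if "resource" in event:
        return "api_gateway"

    raise TypeError("Unable to determine handler from trigger event")


# Hypothetical, heavily trimmed events used only for illustration.
print(detect_event_source({"requestContext": {"elb": {}}}))
print(detect_event_source({"Records": [{"cf": {}}]}))
print(detect_event_source({"version": "2.0", "requestContext": {}}))
print(detect_event_source({"resource": "/", "requestContext": {}}))
```

Ordering matters here: ALB and API Gateway events both carry a `requestContext`, so the `elb` key has to be checked before the more general gateway checks.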
130 changes: 130 additions & 0 deletions mangum/handlers/abstract_handler.py
@@ -0,0 +1,130 @@
import base64
from abc import ABCMeta, abstractmethod
from typing import Dict, Any, TYPE_CHECKING, Tuple, List

from .. import Response, Scope

if TYPE_CHECKING:  # pragma: no cover
    from awslambdaric.lambda_context import LambdaContext


class AbstractHandler(metaclass=ABCMeta):
    def __init__(
        self,
        trigger_event: Dict[str, Any],
        trigger_context: "LambdaContext",
        **kwargs: Dict[str, Any]
    ):
        self.trigger_event = trigger_event
        self.trigger_context = trigger_context

    @property
    @abstractmethod
    def scope(self) -> Scope:
        """
        Parse an ASGI scope from the request event
        """

    @property
    @abstractmethod
    def body(self) -> bytes:
        """
        Get the raw body from the request event
        """

    @abstractmethod
    def transform_response(self, response: Response) -> Dict[str, Any]:
        """
        After running our application, transform the response to the correct format for this handler
        """

    @staticmethod
    def from_trigger(
        trigger_event: Dict[str, Any],
        trigger_context: "LambdaContext",
        **kwargs: Dict[str, Any]
    ) -> "AbstractHandler":
        """
        A factory method that determines which handler to use. All this code should probably stay in one place to make
        sure we are able to uniquely find each handler correctly.
        """

        # These should be ordered from most specific to least for best accuracy
        if (
            "requestContext" in trigger_event
            and "elb" in trigger_event["requestContext"]
        ):
            from . import AwsAlb

            return AwsAlb(trigger_event, trigger_context, **kwargs)

        if (
            "Records" in trigger_event
            and len(trigger_event["Records"]) > 0
            and "cf" in trigger_event["Records"][0]
        ):
            from . import AwsCfLambdaAtEdge

            return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs)

        if "version" in trigger_event and "requestContext" in trigger_event:
            from . import AwsHttpGateway

            return AwsHttpGateway(trigger_event, trigger_context, **kwargs)

        if "resource" in trigger_event:
            from . import AwsApiGateway

            return AwsApiGateway(trigger_event, trigger_context, **kwargs)  # type: ignore

        raise TypeError("Unable to determine handler from trigger event")

    @staticmethod
    def _handle_multi_value_headers(
        response_headers: List[List[bytes]],
    ) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
        headers: Dict[str, str] = {}
        multi_value_headers: Dict[str, List[str]] = {}
        for key, value in response_headers:
            lower_key = key.decode().lower()
            if lower_key in multi_value_headers:
                multi_value_headers[lower_key].append(value.decode())
            elif lower_key in headers:
                # Move existing to multi_value_headers and append current
                multi_value_headers[lower_key] = [
                    headers[lower_key],
                    value.decode(),
                ]
                del headers[lower_key]
            else:
                headers[lower_key] = value.decode()
        return headers, multi_value_headers

    @staticmethod
    def _handle_base64_response_body(
        body: bytes, headers: Dict[str, str]
    ) -> Tuple[str, bool]:
        """
        To ease debugging for our users, try and return strings where we can, otherwise to ensure maximum
        compatibility with binary data, base64 encode it.
        """
        is_base64_encoded = False
        output_body = ""
        if body != b"":
            from ..adapter import DEFAULT_TEXT_MIME_TYPES

            for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
                if text_mime_type in headers.get("content-type", ""):
                    try:
                        output_body = body.decode()
                    except UnicodeDecodeError:
                        # Can't decode it, base64 it and be done
                        output_body = base64.b64encode(body).decode()
                        is_base64_encoded = True
                    break
            else:
                # Not text, base64 encode
                output_body = base64.b64encode(body).decode()
                is_base64_encoded = True

        return output_body, is_base64_encoded
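
The two static helpers above can be tried outside Mangum. The following is a simplified standalone sketch of the same two ideas: splitting ASGI response headers into single-value and multi-value dictionaries, and returning text where possible while base64-encoding everything else. The names `split_headers`, `encode_body`, and `TEXT_MIME_TYPES` are hypothetical, and the MIME list is shortened for the example.

```python
import base64
from typing import Dict, List, Tuple

# Shortened, illustrative list; not the project's DEFAULT_TEXT_MIME_TYPES.
TEXT_MIME_TYPES = ["text/", "application/json"]


def split_headers(
    raw_headers: List[List[bytes]],
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    """Keep unique headers in one dict and repeated headers in another."""
    headers: Dict[str, str] = {}
    multi: Dict[str, List[str]] = {}
    for key, value in raw_headers:
        name = key.decode().lower()
        if name in multi:
            multi[name].append(value.decode())
        elif name in headers:
            # Second occurrence: move the first value over and add this one.
            multi[name] = [headers.pop(name), value.decode()]
        else:
            headers[name] = value.decode()
    return headers, multi


def encode_body(body: bytes, content_type: str) -> Tuple[str, bool]:
    """Return (body, is_base64_encoded); text types are decoded if possible."""
    if not body:
        return "", False
    if any(mime in content_type for mime in TEXT_MIME_TYPES):
        try:
            return body.decode(), False
        except UnicodeDecodeError:
            pass  # Declared as text but not decodable: fall through to base64.
    return base64.b64encode(body).decode(), True


headers, multi = split_headers(
    [
        [b"Set-Cookie", b"a=1"],
        [b"Content-Type", b"text/plain"],
        [b"set-cookie", b"b=2"],
    ]
)
text_result = encode_body(b"hello", headers["content-type"])
binary_result = encode_body(b"\x89PNG", "image/png")
print(headers, multi, text_result, binary_result)
```

A repeated header such as `Set-Cookie` ends up only in the multi-value dictionary, which matches the "move existing and append" branch in `_handle_multi_value_headers`.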