Skip to content

Commit 7a4134e

Browse files
committed
Refactor handlers to be separate from core logic
1 parent c6a4b35 commit 7a4134e

23 files changed

+2351
-814
lines changed

docs/http.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
# HTTP
22

3-
Mangum provides support for both [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) and the newer [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) APIs in API Gateway. It also includes configurable binary response support.
4-
3+
Mangum provides support for the following AWS HTTP Lambda Event Source:
4+
5+
* [API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html)
6+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-apigateway.html))
7+
* [HTTP Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html)
8+
([Event Examples](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html))
9+
* [Application Load Balancer (ALB)](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/introduction.html)
10+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html))
11+
* [CloudFront Lambda@Edge](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/lambda-at-the-edge.html)
12+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html))
13+
514
```python
615
from fastapi import FastAPI
716
from fastapi.middleware.gzip import GZipMiddleware

mangum/adapter.py

Lines changed: 25 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,32 @@
11
import base64
2-
import typing
2+
from typing import Any, Callable, ContextManager, Dict, Optional, List, TYPE_CHECKING
33
import logging
44
import urllib.parse
55

66
from dataclasses import dataclass, InitVar
77
from contextlib import ExitStack
88

9-
from mangum.types import ASGIApp, Scope
9+
from mangum.handlers import AbstractHandler
10+
from mangum.response import Response
11+
from mangum.types import ASGIApp, ScopeDict
1012
from mangum.protocols.lifespan import LifespanCycle
1113
from mangum.protocols.http import HTTPCycle
1214
from mangum.exceptions import ConfigurationError
1315

14-
if typing.TYPE_CHECKING: # pragma: no cover
16+
if TYPE_CHECKING: # pragma: no cover
1517
from awslambdaric.lambda_context import LambdaContext
1618

1719
DEFAULT_TEXT_MIME_TYPES = [
20+
"text/",
1821
"application/json",
1922
"application/javascript",
2023
"application/xml",
2124
"application/vnd.api+json",
2225
]
2326

24-
LOG_LEVELS = {
25-
"critical": logging.CRITICAL,
26-
"error": logging.ERROR,
27-
"warning": logging.WARNING,
28-
"info": logging.INFO,
29-
"debug": logging.DEBUG,
30-
}
27+
logger = logging.getLogger("mangum")
3128

3229

33-
@dataclass
3430
class Mangum:
3531
"""
3632
Creates an adapter instance.
@@ -41,153 +37,40 @@ class Mangum:
4137
and `off`. Default is `auto`.
4238
* **log_level** - A string to configure the log level. Choices are: `info`,
4339
`critical`, `error`, `warning`, and `debug`. Default is `info`.
44-
* **api_gateway_base_path** - Base path to strip from URL when using a custom
45-
domain name.
4640
* **text_mime_types** - A list of MIME types to include with the defaults that
4741
should not return a binary response in API Gateway.
4842
"""
4943

5044
app: ASGIApp
5145
lifespan: str = "auto"
52-
log_level: str = "info"
53-
api_gateway_base_path: typing.Optional[str] = None
54-
text_mime_types: InitVar[typing.Optional[typing.List[str]]] = None
5546

56-
def __post_init__(self, text_mime_types: typing.Optional[typing.List[str]]) -> None:
47+
def __init__(
48+
self,
49+
app: ASGIApp,
50+
lifespan: str = "auto",
51+
**handler_kwargs,
52+
):
53+
self.app = app
54+
self.lifespan = lifespan
55+
self.handler_kwargs = handler_kwargs
56+
5757
if self.lifespan not in ("auto", "on", "off"):
5858
raise ConfigurationError(
5959
"Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
6060
)
6161

62-
if self.log_level not in ("critical", "error", "warning", "info", "debug"):
63-
raise ConfigurationError(
64-
"Invalid argument supplied for `log_level`. "
65-
"Choices are: critical|error|warning|info|debug"
66-
)
67-
68-
self.logger = logging.getLogger("mangum")
69-
self.logger.setLevel(LOG_LEVELS[self.log_level])
70-
71-
should_prefix_base_path = (
72-
self.api_gateway_base_path
73-
and not self.api_gateway_base_path.startswith("/")
74-
)
75-
if should_prefix_base_path:
76-
self.api_gateway_base_path = f"/{self.api_gateway_base_path}"
77-
78-
if text_mime_types:
79-
text_mime_types += DEFAULT_TEXT_MIME_TYPES
80-
else:
81-
text_mime_types = DEFAULT_TEXT_MIME_TYPES
82-
self.text_mime_types = text_mime_types
83-
8462
def __call__(self, event: dict, context: "LambdaContext") -> dict:
85-
self.logger.debug("Event received.")
63+
logger.debug("Event received.")
8664

8765
with ExitStack() as stack:
8866
if self.lifespan != "off":
89-
lifespan_cycle: typing.ContextManager = LifespanCycle(
90-
self.app, self.lifespan
91-
)
67+
lifespan_cycle: ContextManager = LifespanCycle(self.app, self.lifespan)
9268
stack.enter_context(lifespan_cycle)
9369

94-
is_binary = event.get("isBase64Encoded", False)
95-
initial_body = event.get("body") or b""
96-
if is_binary:
97-
initial_body = base64.b64decode(initial_body)
98-
elif not isinstance(initial_body, bytes):
99-
initial_body = initial_body.encode()
100-
101-
scope = self.create_scope(event, context)
102-
http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types)
103-
response = http_cycle(self.app, initial_body)
104-
105-
return response
106-
107-
def create_scope(self, event: dict, context: "LambdaContext") -> Scope:
108-
"""
109-
Creates a scope object according to ASGI specification from a Lambda Event.
110-
111-
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
112-
113-
The event comes from various sources: AWS ALB, AWS API Gateway of different
114-
versions and configurations(multivalue header, etc).
115-
Thus, some heuristics is applied to guess an event type.
116-
117-
"""
118-
request_context = event["requestContext"]
119-
120-
if event.get("multiValueHeaders"):
121-
headers = {
122-
k.lower(): ", ".join(v) if isinstance(v, list) else ""
123-
for k, v in event.get("multiValueHeaders", {}).items()
124-
}
125-
elif event.get("headers"):
126-
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
127-
else:
128-
headers = {}
129-
130-
# API Gateway v2
131-
if event.get("version") == "2.0":
132-
source_ip = request_context["http"]["sourceIp"]
133-
path = request_context["http"]["path"]
134-
http_method = request_context["http"]["method"]
135-
query_string = event.get("rawQueryString", "").encode()
136-
137-
if event.get("cookies"):
138-
headers["cookie"] = "; ".join(event.get("cookies", []))
139-
140-
# API Gateway v1 / ELB
141-
else:
142-
if "elb" in request_context:
143-
# NOTE: trust only the most right side value
144-
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
145-
else:
146-
source_ip = request_context.get("identity", {}).get("sourceIp")
147-
148-
path = event["path"]
149-
http_method = event["httpMethod"]
150-
151-
if event.get("multiValueQueryStringParameters"):
152-
query_string = urllib.parse.urlencode(
153-
event.get("multiValueQueryStringParameters", {}), doseq=True
154-
).encode()
155-
elif event.get("queryStringParameters"):
156-
query_string = urllib.parse.urlencode(
157-
event.get("queryStringParameters", {})
158-
).encode()
159-
else:
160-
query_string = b""
161-
162-
server_name = headers.get("host", "mangum")
163-
if ":" not in server_name:
164-
server_port = headers.get("x-forwarded-port", 80)
165-
else:
166-
server_name, server_port = server_name.split(":") # pragma: no cover
167-
server = (server_name, int(server_port))
168-
client = (source_ip, 0)
169-
170-
if not path: # pragma: no cover
171-
path = "/"
172-
elif self.api_gateway_base_path:
173-
if path.startswith(self.api_gateway_base_path):
174-
path = path[len(self.api_gateway_base_path) :]
175-
176-
scope = {
177-
"type": "http",
178-
"http_version": "1.1",
179-
"method": http_method,
180-
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
181-
"path": urllib.parse.unquote(path),
182-
"raw_path": None,
183-
"root_path": "",
184-
"scheme": headers.get("x-forwarded-proto", "https"),
185-
"query_string": query_string,
186-
"server": server,
187-
"client": client,
188-
"asgi": {"version": "3.0"},
189-
"aws.event": event,
190-
"aws.context": context,
191-
}
70+
handler = AbstractHandler.from_trigger(
71+
event, context, **self.handler_kwargs
72+
)
73+
http_cycle = HTTPCycle(handler.scope.as_dict())
74+
response = http_cycle(self.app, handler.body)
19275

193-
return scope
76+
return handler.transform_response(response)

mangum/handlers/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .abstract_handler import AbstractHandler
2+
from .aws_alb import AwsAlb
3+
from .aws_api_gateway import AwsApiGateway
4+
from .aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
5+
from .aws_http_gateway import AwsHttpGateway
6+
7+
__all__ = [
8+
"AbstractHandler",
9+
"AwsAlb",
10+
"AwsApiGateway",
11+
"AwsCfLambdaAtEdge",
12+
"AwsHttpGateway",
13+
]
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import base64
2+
from abc import ABCMeta, abstractmethod
3+
4+
from typing import Dict, Any, TYPE_CHECKING, Tuple, List
5+
6+
from mangum.response import Response
7+
from mangum.scope import Scope
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from awslambdaric.lambda_context import LambdaContext
11+
12+
13+
class AbstractHandler(metaclass=ABCMeta):
14+
def __init__(self, trigger_event: Dict[str, Any], trigger_context: "LambdaContext"):
15+
self.trigger_event = trigger_event
16+
self.trigger_context = trigger_context
17+
18+
@property
19+
@abstractmethod
20+
def scope(self) -> Scope:
21+
"""
22+
Parse an ASGI scope from the request event
23+
"""
24+
25+
@property
26+
@abstractmethod
27+
def body(self) -> bytes:
28+
"""
29+
Get the raw body from the request event
30+
"""
31+
32+
@abstractmethod
33+
def transform_response(self, response: Response) -> Dict[str, Any]:
34+
"""
35+
After running our application, transform the response to the correct format for this handler
36+
"""
37+
38+
@staticmethod
39+
def from_trigger(
40+
trigger_event: Dict[str, Any], trigger_context: "LambdaContext", **kwargs
41+
):
42+
"""
43+
A factory method that determines which handler to use. All this code should probably stay in one place to make
44+
sure we are able to uniquely find each handler correctly.
45+
"""
46+
47+
# These should be ordered from most specific to least for best accuracy
48+
if (
49+
"requestContext" in trigger_event
50+
and "elb" in trigger_event["requestContext"]
51+
):
52+
from . import AwsAlb
53+
54+
return AwsAlb(trigger_event, trigger_context, **kwargs)
55+
56+
if (
57+
"Records" in trigger_event
58+
and len(trigger_event["Records"]) > 0
59+
and "cf" in trigger_event["Records"][0]
60+
):
61+
from . import AwsCfLambdaAtEdge
62+
63+
return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs)
64+
65+
if "version" in trigger_event and "requestContext" in trigger_event:
66+
from . import AwsHttpGateway
67+
68+
return AwsHttpGateway(trigger_event, trigger_context, **kwargs)
69+
70+
if "resource" in trigger_event:
71+
from . import AwsApiGateway
72+
73+
return AwsApiGateway(trigger_event, trigger_context, **kwargs)
74+
75+
raise TypeError("Unable to determine handler from trigger event")
76+
77+
@staticmethod
78+
def _handle_multi_value_headers(
79+
response_headers,
80+
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
81+
headers: Dict[str, str] = {}
82+
multi_value_headers: Dict[str, List[str]] = {}
83+
for key, value in response_headers:
84+
lower_key = key.decode().lower()
85+
if lower_key in multi_value_headers:
86+
multi_value_headers[lower_key].append(value.decode())
87+
elif lower_key in headers:
88+
# Move existing to multi_value_headers and append current
89+
multi_value_headers[lower_key] = [
90+
headers[lower_key],
91+
value.decode(),
92+
]
93+
del headers[lower_key]
94+
else:
95+
headers[lower_key] = value.decode()
96+
return headers, multi_value_headers
97+
98+
@staticmethod
99+
def _handle_base64_response_body(
100+
body: bytes, headers: Dict[str, str]
101+
) -> Tuple[str, bool]:
102+
"""
103+
To ease debugging for our users, try and return strings where we can, otherwise to ensure maximum
104+
compatibility with binary data, base64 encode it.
105+
"""
106+
is_base64_encoded = False
107+
if body != b"":
108+
from ..adapter import DEFAULT_TEXT_MIME_TYPES
109+
110+
for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
111+
if text_mime_type in headers.get("content-type", ""):
112+
try:
113+
body = body.decode()
114+
except UnicodeDecodeError:
115+
# Can't decode it, base64 it and be done
116+
body = base64.b64encode(body).decode()
117+
is_base64_encoded = True
118+
break
119+
else:
120+
# Not text, base64 encode
121+
body = base64.b64encode(body).decode()
122+
is_base64_encoded = True
123+
124+
return body, is_base64_encoded

0 commit comments

Comments
 (0)