Skip to content

Commit e549e4b

Browse files
committed
Refactor handlers to be separate from core logic
Main tests passing, could use more coverage. Needs a look for regression
1 parent 18574b8 commit e549e4b

23 files changed

+2075
-980
lines changed

docs/http.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
# HTTP
22

3-
Mangum provides support for both [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) and the newer [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) APIs in API Gateway. It also includes configurable binary response support.
4-
3+
Mangum provides support for the following AWS HTTP Lambda Event Source:
4+
5+
* [API Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html)
6+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-apigateway.html))
7+
* [HTTP Gateway](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html)
8+
([Event Examples](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html))
9+
* [Application Load Balancer (ALB)](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/introduction.html)
10+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html))
11+
* [CloudFront Lambda@Edge](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/lambda-at-the-edge.html)
12+
([Event Examples](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html))
13+
514
```python
615
from fastapi import FastAPI
716
from fastapi.middleware.gzip import GZipMiddleware

mangum/adapter.py

Lines changed: 21 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,32 @@
11
import base64
2-
import typing
2+
from typing import Any, Callable, ContextManager, Dict, Optional, List, TYPE_CHECKING
33
import logging
44
import urllib.parse
55

66
from dataclasses import dataclass, InitVar
77
from contextlib import ExitStack
88

9-
from mangum.types import ASGIApp, Scope
9+
from mangum.handlers import AbstractHandler
10+
from mangum.response import Response
11+
from mangum.types import ASGIApp, ScopeDict
1012
from mangum.protocols.lifespan import LifespanCycle
1113
from mangum.protocols.http import HTTPCycle
1214
from mangum.exceptions import ConfigurationError
1315

14-
if typing.TYPE_CHECKING: # pragma: no cover
16+
if TYPE_CHECKING: # pragma: no cover
1517
from awslambdaric.lambda_context import LambdaContext
1618

1719
DEFAULT_TEXT_MIME_TYPES = [
20+
"text/",
1821
"application/json",
1922
"application/javascript",
2023
"application/xml",
2124
"application/vnd.api+json",
2225
]
2326

24-
LOG_LEVELS = {
25-
"critical": logging.CRITICAL,
26-
"error": logging.ERROR,
27-
"warning": logging.WARNING,
28-
"info": logging.INFO,
29-
"debug": logging.DEBUG,
30-
}
27+
logger = logging.getLogger("mangum")
3128

3229

33-
@dataclass
3430
class Mangum:
3531
"""
3632
Creates an adapter instance.
@@ -41,152 +37,38 @@ class Mangum:
4137
and `off`. Default is `auto`.
4238
* **log_level** - A string to configure the log level. Choices are: `info`,
4339
`critical`, `error`, `warning`, and `debug`. Default is `info`.
44-
* **api_gateway_base_path** - Base path to strip from URL when using a custom
45-
domain name.
4640
* **text_mime_types** - A list of MIME types to include with the defaults that
4741
should not return a binary response in API Gateway.
4842
"""
4943

5044
app: ASGIApp
5145
lifespan: str = "auto"
52-
log_level: str = "info"
53-
api_gateway_base_path: typing.Optional[str] = None
54-
text_mime_types: InitVar[typing.Optional[typing.List[str]]] = None
5546

56-
def __post_init__(self, text_mime_types: typing.Optional[typing.List[str]]) -> None:
47+
def __init__(self, app: ASGIApp,
48+
lifespan: str = "auto",
49+
**handler_kwargs,
50+
):
51+
self.app = app
52+
self.lifespan = lifespan
53+
self.handler_kwargs = handler_kwargs
54+
5755
if self.lifespan not in ("auto", "on", "off"):
5856
raise ConfigurationError(
5957
"Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
6058
)
6159

62-
if self.log_level not in ("critical", "error", "warning", "info", "debug"):
63-
raise ConfigurationError(
64-
"Invalid argument supplied for `log_level`. "
65-
"Choices are: critical|error|warning|info|debug"
66-
)
67-
68-
self.logger = logging.getLogger("mangum")
69-
self.logger.setLevel(LOG_LEVELS[self.log_level])
70-
71-
should_prefix_base_path = (
72-
self.api_gateway_base_path
73-
and not self.api_gateway_base_path.startswith("/")
74-
)
75-
if should_prefix_base_path:
76-
self.api_gateway_base_path = f"/{self.api_gateway_base_path}"
77-
78-
if text_mime_types:
79-
text_mime_types += DEFAULT_TEXT_MIME_TYPES
80-
else:
81-
text_mime_types = DEFAULT_TEXT_MIME_TYPES
82-
self.text_mime_types = text_mime_types
83-
8460
def __call__(self, event: dict, context: "LambdaContext") -> dict:
85-
self.logger.debug("Event received.")
61+
logger.debug("Event received.")
8662

8763
with ExitStack() as stack:
8864
if self.lifespan != "off":
89-
lifespan_cycle: typing.ContextManager = LifespanCycle(
65+
lifespan_cycle: ContextManager = LifespanCycle(
9066
self.app, self.lifespan
9167
)
9268
stack.enter_context(lifespan_cycle)
9369

94-
is_binary = event.get("isBase64Encoded", False)
95-
initial_body = event.get("body") or b""
96-
if is_binary:
97-
initial_body = base64.b64decode(initial_body)
98-
elif not isinstance(initial_body, bytes):
99-
initial_body = initial_body.encode()
100-
101-
scope = self._create_scope(event, context)
102-
http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types)
103-
response = http_cycle(self.app, initial_body)
104-
105-
return response
106-
107-
def _create_scope(self, event: dict, context: "LambdaContext") -> Scope:
108-
"""Creates a scope object according to ASGI specification from a Lambda Event.
109-
110-
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
111-
112-
The event comes from various sources: AWS ALB, AWS API Gateway of different
113-
versions and configurations(multivalue header, etc).
114-
Thus, some heuristics is applied to guess an event type.
115-
116-
"""
117-
request_context = event["requestContext"]
118-
119-
if event.get("multiValueHeaders"):
120-
headers = {
121-
k.lower(): ", ".join(v) if isinstance(v, list) else ""
122-
for k, v in event.get("multiValueHeaders", {}).items()
123-
}
124-
elif event.get("headers"):
125-
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
126-
else:
127-
headers = {}
128-
129-
# API Gateway v2
130-
if event.get("version") == "2.0":
131-
source_ip = request_context["http"]["sourceIp"]
132-
path = request_context["http"]["path"]
133-
http_method = request_context["http"]["method"]
134-
query_string = event.get("rawQueryString", "").encode()
135-
136-
if event.get("cookies"):
137-
headers["cookie"] = "; ".join(event.get("cookies", []))
138-
139-
# API Gateway v1 / ELB
140-
else:
141-
if "elb" in request_context:
142-
# NOTE: trust only the most right side value
143-
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
144-
else:
145-
source_ip = request_context.get("identity", {}).get("sourceIp")
146-
147-
path = event["path"]
148-
http_method = event["httpMethod"]
149-
150-
if event.get("multiValueQueryStringParameters"):
151-
query_string = urllib.parse.urlencode(
152-
event.get("multiValueQueryStringParameters", {}), doseq=True
153-
).encode()
154-
elif event.get("queryStringParameters"):
155-
query_string = urllib.parse.urlencode(
156-
event.get("queryStringParameters", {})
157-
).encode()
158-
else:
159-
query_string = b""
160-
161-
server_name = headers.get("host", "mangum")
162-
if ":" not in server_name:
163-
server_port = headers.get("x-forwarded-port", 80)
164-
else:
165-
server_name, server_port = server_name.split(":") # pragma: no cover
166-
server = (server_name, int(server_port))
167-
client = (source_ip, 0)
168-
169-
if not path: # pragma: no cover
170-
path = "/"
171-
elif self.api_gateway_base_path:
172-
if path.startswith(self.api_gateway_base_path):
173-
path = path[len(self.api_gateway_base_path) :]
174-
175-
scope = {
176-
"type": "http",
177-
"http_version": "1.1",
178-
"method": http_method,
179-
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
180-
"path": urllib.parse.unquote(path),
181-
"raw_path": None,
182-
"root_path": "",
183-
"scheme": headers.get("x-forwarded-proto", "https"),
184-
"query_string": query_string,
185-
"server": server,
186-
"client": client,
187-
"asgi": {"version": "3.0"},
188-
"aws.event": event,
189-
"aws.context": context,
190-
}
70+
handler = AbstractHandler.from_trigger(event, context, **self.handler_kwargs)
71+
http_cycle = HTTPCycle(handler.scope.as_dict())
72+
response = http_cycle(self.app, handler.body)
19173

192-
return scope
74+
return handler.transform_response(response)

mangum/handlers/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .abstract_handler import AbstractHandler
2+
from .aws_alb import AwsAlb
3+
from .aws_api_gateway import AwsApiGateway
4+
from .aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
5+
from .aws_http_gateway import AwsHttpGateway
6+
7+
__all__ = [
8+
"AbstractHandler",
9+
"AwsAlb",
10+
"AwsApiGateway",
11+
"AwsCfLambdaAtEdge",
12+
"AwsHttpGateway"
13+
]
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import base64
2+
from abc import ABCMeta, abstractmethod
3+
4+
from typing import Dict, Any, TYPE_CHECKING, Tuple, List
5+
6+
from mangum.response import Response
7+
from mangum.scope import Scope
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from awslambdaric.lambda_context import LambdaContext
11+
12+
13+
class AbstractHandler(metaclass=ABCMeta):
14+
15+
def __init__(self, trigger_event: Dict[str, Any], trigger_context: "LambdaContext"):
16+
self.trigger_event = trigger_event
17+
self.trigger_context = trigger_context
18+
19+
@property
20+
@abstractmethod
21+
def scope(self) -> Scope:
22+
"""
23+
Parse an ASGI scope from the request event
24+
"""
25+
26+
@property
27+
@abstractmethod
28+
def body(self) -> bytes:
29+
"""
30+
Get the raw body from the request event
31+
"""
32+
33+
@abstractmethod
34+
def transform_response(self, response: Response) -> Dict[str, Any]:
35+
"""
36+
After running our application, transform the response to the correct format for this handler
37+
"""
38+
39+
@staticmethod
40+
def from_trigger(trigger_event: Dict[str, Any], trigger_context: "LambdaContext", **kwargs):
41+
"""
42+
A factory method that determines which handler to use. All this code should probably stay in one place to make
43+
sure we are able to uniquely find each handler correctly.
44+
"""
45+
46+
# These should be ordered from most specific to least for best accuracy
47+
if "requestContext" in trigger_event and "elb" in trigger_event["requestContext"]:
48+
from . import AwsAlb
49+
return AwsAlb(trigger_event, trigger_context, **kwargs)
50+
51+
if "Records" in trigger_event and len(trigger_event["Records"]) > 0 and "cf" in trigger_event["Records"][0]:
52+
from . import AwsCfLambdaAtEdge
53+
return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs)
54+
55+
if "version" in trigger_event and "requestContext" in trigger_event:
56+
from . import AwsHttpGateway
57+
return AwsHttpGateway(trigger_event, trigger_context, **kwargs)
58+
59+
if "resource" in trigger_event:
60+
from . import AwsApiGateway
61+
return AwsApiGateway(trigger_event, trigger_context, **kwargs)
62+
63+
raise TypeError("Unable to determine handler from trigger event")
64+
65+
@staticmethod
66+
def _handle_multi_value_headers(response_headers) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
67+
headers: Dict[str, str] = {}
68+
multi_value_headers: Dict[str, List[str]] = {}
69+
for key, value in response_headers:
70+
lower_key = key.decode().lower()
71+
if lower_key in multi_value_headers:
72+
multi_value_headers[lower_key].append(value.decode())
73+
elif lower_key in headers:
74+
# Move existing to multi_value_headers and append current
75+
multi_value_headers[lower_key] = [
76+
headers[lower_key],
77+
value.decode(),
78+
]
79+
del headers[lower_key]
80+
else:
81+
headers[lower_key] = value.decode()
82+
return headers, multi_value_headers
83+
84+
@staticmethod
85+
def _handle_base64_response_body(body: bytes, headers: Dict[str, str]) -> Tuple[str, bool]:
86+
"""
87+
To ease debugging for our users, try and return strings where we can, otherwise to ensure maximum
88+
compatibility with binary data, base64 encode it.
89+
"""
90+
is_base64_encoded = False
91+
if body != b'':
92+
from ..adapter import DEFAULT_TEXT_MIME_TYPES
93+
for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
94+
if text_mime_type in headers.get("content-type", ""):
95+
try:
96+
body = body.decode()
97+
except UnicodeDecodeError:
98+
# Can't decode it, base64 it and be done
99+
body = base64.b64encode(body).decode()
100+
is_base64_encoded = True
101+
break
102+
else:
103+
# Not text, base64 encode
104+
body = base64.b64encode(body).decode()
105+
is_base64_encoded = True
106+
107+
return body, is_base64_encoded

0 commit comments

Comments
 (0)