Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,7 @@ venv.bak/

# IDE Settings
.idea/
.vscode
.devcontainer

.DS_Store
2 changes: 2 additions & 0 deletions docs/adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ handler = Mangum(
app,
lifespan="auto",
api_gateway_base_path=None,
custom_handlers=None,
text_mime_types=None,
)
```

Expand Down
34 changes: 21 additions & 13 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
LambdaAtEdge,
]

DEFAULT_TEXT_MIME_TYPES: List[str] = [
"text/",
"application/json",
"application/javascript",
"application/xml",
"application/vnd.api+json",
"application/vnd.oai.openapi",
]


class Mangum:
def __init__(
Expand All @@ -34,6 +43,7 @@ def __init__(
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
) -> None:
if lifespan not in ("auto", "on", "off"):
raise ConfigurationError(
Expand All @@ -42,24 +52,22 @@ def __init__(

self.app = app
self.lifespan = lifespan
self.api_gateway_base_path = api_gateway_base_path or "/"
self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path)
self.custom_handlers = custom_handlers or []
self.config = LambdaConfig(
api_gateway_base_path=api_gateway_base_path or "/",
text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES],
)

def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
for handler_cls in chain(self.custom_handlers, HANDLERS):
if handler_cls.infer(event, context, self.config):
handler = handler_cls(event, context, self.config)
break
else:
raise RuntimeError( # pragma: no cover
"The adapter was unable to infer a handler to use for the event. This "
"is likely related to how the Lambda function was invoked. (Are you "
"testing locally? Make sure the request payload is valid for a "
"supported handler.)"
)

return handler
return handler_cls(event, context, self.config)
raise RuntimeError( # pragma: no cover
"The adapter was unable to infer a handler to use for the event. This "
"is likely related to how the Lambda function was invoked. (Are you "
"testing locally? Make sure the request payload is valid for a "
"supported handler.)"
)

def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
handler = self.infer(event, context)
Expand Down
2 changes: 1 addition & 1 deletion mangum/handlers/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __call__(self, response: Response) -> dict:

finalized_headers = case_mutated_headers(multi_value_headers)
finalized_body, is_base64_encoded = handle_base64_response_body(
response["body"], finalized_headers
response["body"], finalized_headers, self.config["text_mime_types"]
)

out = {
Expand Down
6 changes: 3 additions & 3 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __call__(self, response: Response) -> dict:
response["headers"]
)
finalized_body, is_base64_encoded = handle_base64_response_body(
response["body"], finalized_headers
response["body"], finalized_headers, self.config["text_mime_types"]
)

return {
Expand Down Expand Up @@ -204,7 +204,7 @@ def __call__(self, response: Response) -> dict:
finalized_headers["content-type"] = "application/json"

finalized_body, is_base64_encoded = handle_base64_response_body(
response["body"], finalized_headers
response["body"], finalized_headers, self.config["text_mime_types"]
)
response_out = {
"statusCode": response["status"],
Expand All @@ -221,7 +221,7 @@ def __call__(self, response: Response) -> dict:
response["headers"]
)
finalized_body, is_base64_encoded = handle_base64_response_body(
response["body"], finalized_headers
response["body"], finalized_headers, self.config["text_mime_types"]
)
return {
"statusCode": response["status"],
Expand Down
2 changes: 1 addition & 1 deletion mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def scope(self) -> Scope:
def __call__(self, response: Response) -> dict:
multi_value_headers, _ = handle_multi_value_headers(response["headers"])
response_body, is_base64_encoded = handle_base64_response_body(
response["body"], multi_value_headers
response["body"], multi_value_headers, self.config["text_mime_types"]
)
finalized_headers: Dict[str, List[Dict[str, str]]] = {
key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}]
Expand Down
16 changes: 4 additions & 12 deletions mangum/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
from mangum.types import Headers


DEFAULT_TEXT_MIME_TYPES = [
"text/",
"application/json",
"application/javascript",
"application/xml",
"application/vnd.api+json",
"application/vnd.oai.openapi",
]


def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes:
body = body or b""
if is_base64:
Expand Down Expand Up @@ -71,12 +61,14 @@ def handle_multi_value_headers(


def handle_base64_response_body(
body: bytes, headers: Dict[str, str]
body: bytes,
headers: Dict[str, str],
text_mime_types: List[str],
) -> Tuple[str, bool]:
is_base64_encoded = False
output_body = ""
if body != b"":
for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
for text_mime_type in text_mime_types:
if text_mime_type in headers.get("content-type", ""):
try:
output_body = body.decode()
Expand Down
1 change: 1 addition & 0 deletions mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class Response(TypedDict):

class LambdaConfig(TypedDict):
api_gateway_base_path: str
text_mime_types: List[str]


class LambdaHandler(Protocol):
Expand Down
41 changes: 41 additions & 0 deletions tests/handlers/test_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,44 @@ async def app(scope, receive, send):
"headers": {"content-type": content_type.decode()},
"body": res_body,
}


def test_aws_alb_response_extra_mime_types():
content_type = b"application/x-yaml"
utf_res_body = "name: 'John Doe'"
raw_res_body = utf_res_body.encode()
b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw=="

async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", content_type]],
}
)
await send({"type": "http.response.body", "body": raw_res_body})

event = get_mock_aws_alb_event("GET", "/test", {}, None, None, False, False)

# Test default behavior
handler = Mangum(app, lifespan="off")
response = handler(event, {})
assert content_type.decode() not in handler.config["text_mime_types"]
assert response == {
"statusCode": 200,
"isBase64Encoded": True,
"headers": {"content-type": content_type.decode()},
"body": b64_res_body,
}

# Test with modified text mime types
handler = Mangum(app, lifespan="off")
handler.config["text_mime_types"].append(content_type.decode())
response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": content_type.decode()},
"body": utf_res_body,
}
43 changes: 43 additions & 0 deletions tests/handlers/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,46 @@ async def app(scope, receive, send):
"multiValueHeaders": {},
"body": res_body,
}


def test_aws_api_gateway_response_extra_mime_types():
content_type = b"application/x-yaml"
utf_res_body = "name: 'John Doe'"
raw_res_body = utf_res_body.encode()
b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw=="

async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", content_type]],
}
)
await send({"type": "http.response.body", "body": raw_res_body})

event = get_mock_aws_api_gateway_event("POST", "/test", {}, None, False)

# Test default behavior
handler = Mangum(app, lifespan="off")
response = handler(event, {})
assert content_type.decode() not in handler.config["text_mime_types"]
assert response == {
"statusCode": 200,
"isBase64Encoded": True,
"headers": {"content-type": content_type.decode()},
"multiValueHeaders": {},
"body": b64_res_body,
}

# Test with modified text mime types
handler = Mangum(app, lifespan="off")
handler.config["text_mime_types"].append(content_type.decode())
response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": content_type.decode()},
"multiValueHeaders": {},
"body": utf_res_body,
}
92 changes: 92 additions & 0 deletions tests/handlers/test_http_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,95 @@ async def app(scope, receive, send):
"headers": {"content-type": content_type.decode()},
"body": res_body,
}


def test_aws_http_gateway_response_v1_extra_mime_types():
content_type = b"application/x-yaml"
utf_res_body = "name: 'John Doe'"
raw_res_body = utf_res_body.encode()
b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw=="

async def app(scope, receive, send):
headers = []
if content_type is not None:
headers.append([b"content-type", content_type])

await send(
{
"type": "http.response.start",
"status": 200,
"headers": headers,
}
)
await send({"type": "http.response.body", "body": raw_res_body})

event = get_mock_aws_http_gateway_event_v1("POST", "/test", {}, None, False)

# Test default behavior
handler = Mangum(app, lifespan="off")
response = handler(event, {})
assert content_type.decode() not in handler.config["text_mime_types"]
assert response == {
"statusCode": 200,
"isBase64Encoded": True,
"headers": {"content-type": content_type.decode()},
"multiValueHeaders": {},
"body": b64_res_body,
}

# Test with modified text mime types
handler = Mangum(app, lifespan="off")
handler.config["text_mime_types"].append(content_type.decode())
response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": content_type.decode()},
"multiValueHeaders": {},
"body": utf_res_body,
}


def test_aws_http_gateway_response_v2_extra_mime_types():
content_type = b"application/x-yaml"
utf_res_body = "name: 'John Doe'"
raw_res_body = utf_res_body.encode()
b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw=="

async def app(scope, receive, send):
headers = []
if content_type is not None:
headers.append([b"content-type", content_type])

await send(
{
"type": "http.response.start",
"status": 200,
"headers": headers,
}
)
await send({"type": "http.response.body", "body": raw_res_body})

event = get_mock_aws_http_gateway_event_v2("POST", "/test", {}, None, False)

# Test default behavior
handler = Mangum(app, lifespan="off")
response = handler(event, {})
assert content_type.decode() not in handler.config["text_mime_types"]
assert response == {
"statusCode": 200,
"isBase64Encoded": True,
"headers": {"content-type": content_type.decode()},
"body": b64_res_body,
}

# Test with modified text mime types
handler = Mangum(app, lifespan="off")
handler.config["text_mime_types"].append(content_type.decode())
response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": content_type.decode()},
"body": utf_res_body,
}
Loading