Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .coveragerc

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v2
with:
version: "0.4.12"
version: "0.9.18"
enable-cache: true

- name: Set up Python
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v2
with:
version: "0.4.12"
version: "0.9.18"
enable-cache: true

- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v2
with:
version: "0.4.12"
version: "0.9.18"
enable-cache: true

- name: Set up Python
Expand Down
91 changes: 91 additions & 0 deletions mangum/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

__all__ = ["asyncio_run", "iscoroutinefunction"]

if sys.version_info >= (3, 14):
from inspect import iscoroutinefunction
else:
from asyncio import iscoroutinefunction

_T = TypeVar("_T")

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("asyncio.run() cannot be called from a running event loop")

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
27 changes: 15 additions & 12 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import logging
from contextlib import ExitStack
from itertools import chain
from typing import Any

from mangum._compat import asyncio_run
from mangum.exceptions import ConfigurationError
from mangum.handlers import ALB, APIGateway, HTTPGateway, LambdaAtEdge
from mangum.protocols import HTTPCycle, LifespanCycle
Expand Down Expand Up @@ -59,17 +59,20 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
)

def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict[str, Any]:
handler = self.infer(event, context)
scope = handler.scope
with ExitStack() as stack:
async def handle_request() -> dict[str, Any]:
handler = self.infer(event, context)
scope = handler.scope

if self.lifespan in ("auto", "on"):
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)
scope.update({"state": lifespan_cycle.lifespan_state.copy()})

http_cycle = HTTPCycle(scope, handler.body)
http_response = http_cycle(self.app)

return handler(http_response)
async with lifespan_cycle:
scope.update({"state": lifespan_cycle.lifespan_state.copy()})
http_cycle = HTTPCycle(scope, handler.body)
http_response = await http_cycle(self.app)
return handler(http_response)
else:
http_cycle = HTTPCycle(scope, handler.body)
http_response = await http_cycle(self.app)
return handler(http_response)

assert False, "unreachable" # pragma: no cover
return asyncio_run(handle_request())
22 changes: 4 additions & 18 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,11 @@ def __init__(self, scope: Scope, body: bytes) -> None:
self.state = HTTPCycleState.REQUEST
self.logger = logging.getLogger("mangum.http")
self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
self.app_queue.put_nowait(
{
"type": "http.request",
"body": body,
"more_body": False,
}
)
self.app_queue.put_nowait({"type": "http.request", "body": body, "more_body": False})

def __call__(self, app: ASGI) -> Response:
asgi_instance = self.run(app)
loop = asyncio.get_event_loop()
asgi_task = loop.create_task(asgi_instance)
loop.run_until_complete(asgi_task)

return {
"status": self.status,
"headers": self.headers,
"body": self.body,
}
async def __call__(self, app: ASGI) -> Response:
await self.run(app)
return {"status": self.status, "headers": self.headers, "body": self.body}

async def run(self, app: ASGI) -> None:
try:
Expand Down
12 changes: 6 additions & 6 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,26 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None:
self.lifespan = lifespan
self.state: LifespanCycleState = LifespanCycleState.CONNECTING
self.exception: BaseException | None = None
self.loop = asyncio.get_event_loop()
self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
self.startup_event: asyncio.Event = asyncio.Event()
self.shutdown_event: asyncio.Event = asyncio.Event()
self.logger = logging.getLogger("mangum.lifespan")
self.lifespan_state: dict[str, Any] = {}

def __enter__(self) -> None:
async def __aenter__(self) -> LifespanCycle:
"""Runs the event loop for application startup."""
self.loop.create_task(self.run())
self.loop.run_until_complete(self.startup())
asyncio.create_task(self.run())
await self.startup()
return self

def __exit__(
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Runs the event loop for application shutdown."""
self.loop.run_until_complete(self.shutdown())
await self.shutdown()

async def run(self) -> None:
"""Calls the application with the `lifespan` connection scope."""
Expand Down
18 changes: 11 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Internet :: WWW/HTTP",
]
dependencies = ["typing_extensions"]

[tool.uv]
default-groups = ["dev"]

[dependency-groups]
dev = [
"pytest",
"pytest-cov",
"pytest>=8.0.0",
"coverage",
"ruff",
"starlette",
"quart",
"quart>=0.20.0",
"hypercorn>=0.15.0",
"mypy",
"brotli",
Expand Down Expand Up @@ -66,10 +70,6 @@ ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
strict = true

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
addopts = "-rXs --strict-config --strict-markers"
xfail_strict = true
filterwarnings = [
Expand All @@ -86,8 +86,12 @@ filterwarnings = [

[tool.coverage.run]
source_pkgs = ["mangum", "tests"]
omit = ["mangum/_compat.py"]

[tool.coverage.report]
fail_under = 100
skip_covered = true
show_missing = true
exclude_lines = [
"pragma: no cover",
"pragma: nocover",
Expand Down
3 changes: 2 additions & 1 deletion scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

set -x # print executed commands to the terminal

uv run pytest --ignore venv --cov=mangum --cov=tests --cov-fail-under=100 --cov-report=term-missing "${@}"
uv run coverage run -m pytest "${@}"
uv run coverage report
70 changes: 0 additions & 70 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,76 +176,6 @@ def mock_http_api_event_v1(request: pytest.FixtureRequest):
return event


@pytest.fixture
def mock_lambda_at_edge_event(request: pytest.FixtureRequest):
method = request.param[0]
path = request.param[1]
query_string = request.param[2]
body = request.param[3]

headers_raw = {
"accept-encoding": "gzip,deflate",
"x-forwarded-port": "443",
"x-forwarded-for": "192.168.100.1",
"x-forwarded-proto": "https",
"host": "test.execute-api.us-west-2.amazonaws.com",
}
headers = {}
for key, value in headers_raw.items():
headers[key.lower()] = [{"key": key, "value": value}]

event = {
"Records": [
{
"cf": {
"config": {
"distributionDomainName": "mock-distribution.local.localhost",
"distributionId": "ABC123DEF456G",
"eventType": "origin-request",
"requestId": "lBEBo2N0JKYUP2JXwn_4am2xAXB2GzcL2FlwXI8G59PA8wghF2ImFQ==",
},
"request": {
"clientIp": "192.168.100.1",
"headers": headers,
"method": method,
"origin": {
"custom": {
"customHeaders": {
"x-lae-env-custom-var": [
{
"key": "x-lae-env-custom-var",
"value": "environment variable",
}
],
},
"domainName": "www.example.com",
"keepaliveTimeout": 5,
"path": "",
"port": 80,
"protocol": "http",
"readTimeout": 30,
"sslProtocols": ["TLSv1", "TLSv1.1", "TLSv1.2"],
}
},
"querystring": query_string,
"uri": path,
},
}
}
]
}

if body is not None:
event["Records"][0]["cf"]["request"]["body"] = {
"inputTruncated": False,
"action": "read-only",
"encoding": "text",
"data": body,
}

return dict(method=method, path=path, query_string=query_string, body=body, event=event)


@pytest.fixture(scope="session", autouse=True)
def aws_credentials():
"""Mocked AWS Credentials for moto."""
Expand Down
18 changes: 18 additions & 0 deletions tests/handlers/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,21 @@ def test_custom_handler():
"server": ("mangum", 8080),
"type": "http",
}


def test_custom_handler_infer():
"""Test the infer method of CustomHandler."""
event_with_key = {"my-custom-key": 1}
event_without_key = {"other-key": 1}

assert CustomHandler.infer(event_with_key, {}, {"api_gateway_base_path": "/"}) is True
assert CustomHandler.infer(event_without_key, {}, {"api_gateway_base_path": "/"}) is False


def test_custom_handler_call():
"""Test the __call__ method of CustomHandler."""
event = {"my-custom-key": 1}
handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"})

result = handler(status=200, headers=[], body=b"Hello, World!")
assert result == {"statusCode": 200, "headers": {}, "body": "Hello, World!"}
Loading
Loading