Skip to content
16 changes: 9 additions & 7 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from contextlib import ExitStack
from typing import (
Any,
ContextManager,
Dict,
Optional,
TYPE_CHECKING,
Expand Down Expand Up @@ -44,15 +43,18 @@ class Mangum:
* **text_mime_types** - A list of MIME types to include with the defaults that
should not return a binary response in API Gateway.
* **dsn** - A connection string required to configure a supported WebSocket backend.
* **api_gateway_base_path** - A string specifying the part of the url path after
which the server routing begins.
* **api_gateway_endpoint_url** - A string endpoint url to use for API Gateway when
sending data to WebSocket connections. Default is to determine this automatically.
* **api_gateway_region_name** - A string region name to use for API Gateway when
sending data to WebSocket connections. Default is `AWS_REGION` environment variable.
"""

app: ASGIApp
lifespan: str = "auto"
lifespan: str
dsn: Optional[str] = None
api_gateway_base_path: str
api_gateway_endpoint_url: Optional[str] = None
api_gateway_region_name: Optional[str] = None
Comment thread
relsunkaev marked this conversation as resolved.
Outdated

Expand All @@ -61,32 +63,32 @@ def __init__(
app: ASGIApp,
lifespan: str = "auto",
dsn: Optional[str] = None,
api_gateway_base_path: str = "/",
api_gateway_endpoint_url: Optional[str] = None,
api_gateway_region_name: Optional[str] = None,
**handler_kwargs: Dict[str, Any]
) -> None:
self.app = app
self.lifespan = lifespan
self.dsn = dsn
self.api_gateway_base_path = api_gateway_base_path
self.api_gateway_endpoint_url = api_gateway_endpoint_url
self.api_gateway_region_name = api_gateway_region_name
self.handler_kwargs = handler_kwargs

if self.lifespan not in ("auto", "on", "off"):
raise ConfigurationError(
"Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
)

def __call__(self, event: dict, context: "LambdaContext") -> dict:
def __call__(self, event: Dict[str, Any], context: "LambdaContext") -> dict:
logger.debug("Event received.")

with ExitStack() as stack:
if self.lifespan != "off":
lifespan_cycle: ContextManager = LifespanCycle(self.app, self.lifespan)
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)

handler = AbstractHandler.from_trigger(
event, context, **self.handler_kwargs
event, context, self.api_gateway_base_path
)
request = handler.request

Expand Down
19 changes: 10 additions & 9 deletions mangum/handlers/abstract_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(
self,
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
**kwargs: Dict[str, Any],
):
self.trigger_event = trigger_event
self.trigger_context = trigger_context
Expand Down Expand Up @@ -62,7 +61,7 @@ def api_gateway_endpoint_url(self) -> str:
def from_trigger(
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
**kwargs: Dict[str, Any],
api_gateway_base_path: str = "/",
Comment thread
relsunkaev marked this conversation as resolved.
) -> "AbstractHandler":
"""
A factory method that determines which handler to use. All this code should
Expand All @@ -77,17 +76,15 @@ def from_trigger(
):
from . import AwsAlb

return AwsAlb(trigger_event, trigger_context, **kwargs)
return AwsAlb(trigger_event, trigger_context)

if (
"requestContext" in trigger_event
and "connectionId" in trigger_event["requestContext"]
):
from . import AwsWsGateway

return AwsWsGateway(
trigger_event, trigger_context, **kwargs # type: ignore
)
return AwsWsGateway(trigger_event, trigger_context)

if (
"Records" in trigger_event
Expand All @@ -96,20 +93,24 @@ def from_trigger(
):
from . import AwsCfLambdaAtEdge

return AwsCfLambdaAtEdge(trigger_event, trigger_context, **kwargs)
return AwsCfLambdaAtEdge(trigger_event, trigger_context)

if "version" in trigger_event and "requestContext" in trigger_event:
from . import AwsHttpGateway

return AwsHttpGateway(
trigger_event, trigger_context, **kwargs # type: ignore
trigger_event,
trigger_context,
api_gateway_base_path=api_gateway_base_path,
)

if "resource" in trigger_event:
from . import AwsApiGateway

return AwsApiGateway(
trigger_event, trigger_context, **kwargs # type: ignore
trigger_event,
trigger_context,
api_gateway_base_path=api_gateway_base_path,
)

raise TypeError("Unable to determine handler from trigger event")
Expand Down
14 changes: 8 additions & 6 deletions mangum/handlers/aws_alb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import base64
import urllib.parse
from typing import Any, Dict, Generator, List, Tuple
from typing import Any, Dict, Generator, List, Tuple, Optional, Union
from itertools import islice

from .abstract_handler import AbstractHandler
from .. import Response, Request


def all_casings(input_string: str) -> Generator:
def all_casings(input_string: str) -> Generator[str, None, None]:
"""
Permute all casings of a given string.
A pretty algoritm, via @Amber
Expand All @@ -28,7 +28,7 @@ def all_casings(input_string: str) -> Generator:

def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
"""Create str/str key/value headers, with duplicate keys case mutated."""
headers = {}
headers: Dict[str, str] = {}
for key, values in multi_value_headers.items():
if len(values) > 0:
casings = list(islice(all_casings(key), len(values)))
Expand Down Expand Up @@ -62,7 +62,9 @@ def encode_query_string(self) -> bytes:
Issue: https://github.com/jordaneremieff/mangum/issues/178
"""

params = self.trigger_event.get("multiValueQueryStringParameters")
params: Optional[
Comment thread
relsunkaev marked this conversation as resolved.
Outdated
Dict[str, Union[str, Tuple[str, ...], List[str]]]
] = self.trigger_event.get("multiValueQueryStringParameters")
if not params:
params = self.trigger_event.get("queryStringParameters")
if not params:
Expand All @@ -71,7 +73,7 @@ def encode_query_string(self) -> bytes:
# Loop through the query parameters, unquote each key and value and append the
# pair as a tuple to the query list. If value is a list or a tuple, loop
# through the nested struture and unqote.
query = []
query: List[Tuple[str, str]] = []
for key, value in params.items():
if isinstance(value, (tuple, list)):
for v in value:
Expand All @@ -92,7 +94,7 @@ def transform_headers(self) -> List[Tuple[bytes, bytes]]:
trigger event. However, we act as though they both might exist and pull
headers out of both.
"""
headers = []
headers: List[Tuple[bytes, bytes]] = []
if "multiValueHeaders" in self.trigger_event:
for k, v in self.trigger_event["multiValueHeaders"].items():
for inner_v in v:
Expand Down
3 changes: 1 addition & 2 deletions mangum/handlers/aws_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ def __init__(
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
api_gateway_base_path: str = "/",
**kwargs: Dict[str, Any], # type: ignore
):
super().__init__(trigger_event, trigger_context, **kwargs)
super().__init__(trigger_event, trigger_context)
self.api_gateway_base_path = api_gateway_base_path

@property
Expand Down
2 changes: 1 addition & 1 deletion mangum/handlers/aws_ws_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .abstract_handler import AbstractHandler


def get_server_and_headers(event: dict) -> Tuple: # pragma: no cover
def get_server_and_headers(event: Dict[str, Any]) -> Tuple: # pragma: no cover
if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
Expand Down
6 changes: 3 additions & 3 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
import asyncio
from typing import Optional
import typing
Comment thread
relsunkaev marked this conversation as resolved.
Outdated
import logging
from io import BytesIO
from dataclasses import dataclass
Expand Down Expand Up @@ -46,12 +46,12 @@ class HTTPCycle:

request: Request
state: HTTPCycleState = HTTPCycleState.REQUEST
response: Optional[Response] = None
response: typing.Optional[Response] = None

def __post_init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("mangum.http")
self.loop = asyncio.get_event_loop()
self.app_queue: asyncio.Queue = asyncio.Queue()
self.app_queue: asyncio.Queue[typing.Dict[str, typing.Any]] = asyncio.Queue()
self.body: BytesIO = BytesIO()

def __call__(self, app: ASGIApp, initial_body: bytes) -> Response:
Expand Down
7 changes: 3 additions & 4 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@ class LifespanCycle:
and `off`. Default is `auto`.
* **state** - An enumerated `LifespanCycleState` type that indicates the state of
the ASGI connection.
* **exception** - An exception raised while handling the ASGI event.
* **exception** - An exception raised while handling the ASGI event. This may or
may not be raised depending on the state.
* **app_queue** - An asyncio queue (FIFO) containing messages to be received by the
application.
* **startup_event** - An asyncio event object used to control the application
startup flow.
* **shutdown_event** - An asyncio event object used to control the application
shutdown flow.
* **exception** - An exception raised while handling the ASGI event. This may or
may not be raised depending on the state.
"""

app: ASGIApp
Expand All @@ -63,7 +62,7 @@ class LifespanCycle:
def __post_init__(self) -> None:
self.logger = logging.getLogger("mangum.lifespan")
self.loop = asyncio.get_event_loop()
self.app_queue: asyncio.Queue = asyncio.Queue()
self.app_queue: asyncio.Queue[typing.Dict[str, str]] = asyncio.Queue()
Comment thread
relsunkaev marked this conversation as resolved.
Outdated
self.startup_event: asyncio.Event = asyncio.Event()
self.shutdown_event: asyncio.Event = asyncio.Event()

Expand Down
6 changes: 4 additions & 2 deletions mangum/protocols/websockets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import enum
import asyncio
import copy
import typing
import logging
from io import BytesIO
from dataclasses import dataclass
Expand Down Expand Up @@ -61,7 +63,7 @@ class WebSocketCycle:
def __post_init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("mangum.websocket")
self.loop = asyncio.get_event_loop()
self.app_queue: asyncio.Queue = asyncio.Queue()
self.app_queue: asyncio.Queue[typing.Dict[str, typing.Any]] = asyncio.Queue()
self.body: BytesIO = BytesIO()
self.response: Response = Response(200, [], b"")

Expand Down Expand Up @@ -93,7 +95,7 @@ async def run(self, app: ASGIApp) -> None:
Calls the application with the `websocket` connection scope.
"""
self.scope = await self.websocket.on_message(self.connection_id)
scope = self.scope.copy() # type: ignore
scope = copy.copy(self.scope)
scope.update(
{
"aws.event": self.request.trigger_event,
Expand Down
5 changes: 3 additions & 2 deletions mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BaseRequest:
root_path: str = ""
asgi: Dict[str, str] = field(default_factory=lambda: {"version": "3.0"})

@property
def scope(self) -> Scope:
return {
"http_version": self.http_version,
Expand Down Expand Up @@ -85,7 +86,7 @@ class Request(BaseRequest):

@property
def scope(self) -> Scope:
scope = super().scope()
scope = super().scope
scope.update({"type": self.type, "method": self.method})
return scope

Expand All @@ -103,7 +104,7 @@ class WsRequest(BaseRequest):

@property
def scope(self) -> Scope:
scope = super().scope()
scope = super().scope
scope.update({"type": self.type, "subprotocols": self.subprotocols})
return scope

Expand Down
1 change: 1 addition & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ async def app(scope, receive, send):
def test_default_settings():
handler = Mangum(app)
assert handler.lifespan == "auto"
assert handler.api_gateway_base_path == "/"


@pytest.mark.parametrize(
Expand Down