diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py index b2c64709..33e0968f 100644 --- a/mangum/handlers/aws_alb.py +++ b/mangum/handlers/aws_alb.py @@ -57,9 +57,13 @@ def request(self) -> Request: @property def body(self) -> bytes: - body = self.trigger_event.get("body", b"") + body = self.trigger_event.get("body", b"") or b"" + if self.trigger_event.get("isBase64Encoded", False): - body = base64.b64decode(body) + return base64.b64decode(body) + if not isinstance(body, bytes): + body = body.encode() + return body def transform_response(self, response: Response) -> Dict[str, Any]: diff --git a/mangum/handlers/aws_api_gateway.py b/mangum/handlers/aws_api_gateway.py index 6267ffa3..2e239d44 100644 --- a/mangum/handlers/aws_api_gateway.py +++ b/mangum/handlers/aws_api_gateway.py @@ -95,9 +95,13 @@ def request(self) -> Request: @property def body(self) -> bytes: - body = self.trigger_event.get("body", b"") + body = self.trigger_event.get("body", b"") or b"" + if self.trigger_event.get("isBase64Encoded", False): - body = base64.b64decode(body) + return base64.b64decode(body) + if not isinstance(body, bytes): + body = body.encode() + return body def transform_response(self, response: Response) -> Dict[str, Any]: diff --git a/mangum/handlers/aws_cf_lambda_at_edge.py b/mangum/handlers/aws_cf_lambda_at_edge.py index 85013bd3..b43132c3 100644 --- a/mangum/handlers/aws_cf_lambda_at_edge.py +++ b/mangum/handlers/aws_cf_lambda_at_edge.py @@ -55,9 +55,13 @@ def request(self) -> Request: @property def body(self) -> bytes: request = self.trigger_event["Records"][0]["cf"]["request"] - body = request.get("body", {}).get("data", None) + body = request.get("body", {}).get("data", None) or b"" + if request.get("body", {}).get("encoding", "") == "base64": - body = base64.b64decode(body) + return base64.b64decode(body) + if not isinstance(body, bytes): + body = body.encode() + return body def transform_response(self, response: Response) -> Dict[str, Any]: diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py index 6ff783cd..513c4942 100644 --- a/mangum/handlers/aws_http_gateway.py +++ b/mangum/handlers/aws_http_gateway.py @@ -102,9 +102,13 @@ def request(self) -> Request: @property def body(self) -> bytes: - body = self.trigger_event.get("body", b"") + body = self.trigger_event.get("body", b"") or b"" + if self.trigger_event.get("isBase64Encoded", False): - body = base64.b64decode(body) + return base64.b64decode(body) + if not isinstance(body, bytes): + body = body.encode() + return body def transform_response(self, response: Response) -> Dict[str, Any]: diff --git a/tests/handlers/test_aws_alb.py b/tests/handlers/test_aws_alb.py index 216d8a8d..d922c37d 100644 --- a/tests/handlers/test_aws_alb.py +++ b/tests/handlers/test_aws_alb.py @@ -74,6 +74,8 @@ def test_aws_alb_basic(): example_context = {} handler = AwsAlb(example_event, example_context) + + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -119,7 +121,15 @@ def test_aws_alb_basic(): "query_string,scope_body", [ ("GET", "/hello/world", None, None, False, b"", None), - ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "POST", + "/", + {"name": ["me"]}, + "field1=value1&field2=value2", + False, + b"name=me", + b"field1=value1&field2=value2", + ), ( "GET", "/my/resource", @@ -210,7 +220,10 @@ def test_aws_alb_scope_real( "type": "http", } - assert handler.body == scope_body + if handler.body: + assert handler.body == scope_body + else: + assert handler.body == b"" def test_aws_alb_set_cookies() -> None: diff --git a/tests/handlers/test_aws_api_gateway.py b/tests/handlers/test_aws_api_gateway.py index bc07d080..e35bf4e8 100644 --- a/tests/handlers/test_aws_api_gateway.py +++ b/tests/handlers/test_aws_api_gateway.py @@ -96,6 +96,8 @@ def test_aws_api_gateway_scope_basic(): } example_context = {} handler = AwsApiGateway(example_event, example_context) + + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -128,7 +130,15 @@ def test_aws_api_gateway_scope_basic(): "query_string,scope_body", [ ("GET", "/hello/world", None, None, False, b"", None), - ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "POST", + "/", + {"name": ["me"]}, + "field1=value1&field2=value2", + False, + b"name=me", + b"field1=value1&field2=value2", + ), ( "GET", "/my/resource", @@ -218,7 +228,10 @@ def test_aws_api_gateway_scope_real( "type": "http", } - assert handler.body == scope_body + if handler.body: + assert handler.body == scope_body + else: + assert handler.body == b"" @pytest.mark.parametrize( diff --git a/tests/handlers/test_aws_cf_lambda_at_edge.py b/tests/handlers/test_aws_cf_lambda_at_edge.py index 650e4e2d..4692e791 100644 --- a/tests/handlers/test_aws_cf_lambda_at_edge.py +++ b/tests/handlers/test_aws_cf_lambda_at_edge.py @@ -136,6 +136,7 @@ def test_aws_cf_lambda_at_edge_scope_basic(): example_context = {} handler = AwsCfLambdaAtEdge(example_event, example_context) + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -169,7 +170,15 @@ def test_aws_cf_lambda_at_edge_scope_basic(): "body_base64_encoded,query_string,scope_body", [ ("GET", "/hello/world", None, None, False, b"", None), - ("POST", "/", {"name": ["me"]}, None, False, b"name=me", None), + ( + "POST", + "/", + {"name": ["me"]}, + "field1=value1&field2=value2", + False, + b"name=me", + b"field1=value1&field2=value2", + ), ( "GET", "/my/resource", @@ -241,7 +250,10 @@ def test_aws_api_gateway_scope_real( "type": "http", } - assert handler.body == scope_body + if handler.body: + assert handler.body == scope_body + else: + assert handler.body == b"" @pytest.mark.parametrize( diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_aws_http_gateway.py index 56b9115e..4e87f9df 100644 --- a/tests/handlers/test_aws_http_gateway.py +++ b/tests/handlers/test_aws_http_gateway.py @@ -195,6 +195,8 @@ def test_aws_http_gateway_scope_basic_v1(): } example_context = {} handler = AwsHttpGateway(example_event, example_context) + + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -297,6 +299,8 @@ def test_aws_http_gateway_scope_basic_v2(): } example_context = {} handler = AwsHttpGateway(example_event, example_context) + + assert type(handler.body) == bytes assert handler.request.scope == { "asgi": {"version": "3.0"}, "aws.context": {}, @@ -396,7 +400,10 @@ def test_aws_http_gateway_scope_real_v1( "type": "http", } - assert handler.body == scope_body + if handler.body: + assert handler.body == scope_body + else: + assert handler.body == b"" @pytest.mark.parametrize( @@ -461,7 +468,10 @@ def test_aws_http_gateway_scope_real_v2( "type": "http", } - assert handler.body == scope_body + if handler.body: + assert handler.body == scope_body + else: + assert handler.body == b"" @pytest.mark.parametrize(