Skip to content

Commit f4bda22

Browse files
authored
Fix high severity network attack vector in multipart implementation (#57)
1 parent 0a560a3 commit f4bda22

6 files changed

Lines changed: 137 additions & 12 deletions

File tree

baize/asgi/requests.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,13 @@ async def json(self) -> Any:
176176

177177
raise UnsupportedMediaType("application/json")
178178

179+
async def _parse_multipart(self, boundary: bytes, charset: str) -> FormData:
180+
return FormData(
181+
await parse_multipart(
182+
self.stream(), boundary, charset, file_factory=UploadFile
183+
)
184+
)
185+
179186
@cached_property
180187
async def form(self) -> FormData:
181188
"""
@@ -193,11 +200,7 @@ async def form(self) -> FormData:
193200
if "boundary" not in self.content_type.options:
194201
raise MalformedMultipart("Missing boundary in header content-type")
195202
boundary = self.content_type.options["boundary"].encode("latin-1")
196-
return FormData(
197-
await parse_multipart(
198-
self.stream(), boundary, charset, file_factory=UploadFile
199-
)
200-
)
203+
return await self._parse_multipart(boundary, charset)
201204
if self.content_type == "application/x-www-form-urlencoded":
202205
body = (await self.body).decode(
203206
encoding=self.content_type.options.get("charset", "latin-1")

baize/exceptions.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,17 @@ def abort(
3636
raise HTTPException(status_code=status_code, headers=headers, content=content)
3737

3838

39+
class RequestEntityTooLarge(HTTPException[None]):
40+
"""
41+
413 Request Entity Too Large
42+
"""
43+
44+
def __init__(self, retry_after: Optional[int] = None) -> None:
45+
super().__init__(
46+
413, {"Retry-After": str(retry_after)} if retry_after is not None else None
47+
)
48+
49+
3950
class UnsupportedMediaType(HTTPException[None]):
4051
"""
4152
415 Unsupported Media Type

baize/multipart_helper.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from typing import AsyncIterable, Iterable, List, Optional, Tuple, Type, TypeVar, Union
22

33
from .datastructures import Headers
4+
from .exceptions import RequestEntityTooLarge
45
from .multipart import (
56
Data,
67
Epilogue,
@@ -47,6 +48,8 @@ async def parse_async_stream(
4748
charset: str,
4849
*,
4950
file_factory: Type[_AsyncUploadFile],
51+
max_form_parts: int = 324,
52+
max_form_memory_size: Union[int, None] = None,
5053
) -> List[Tuple[str, Union[str, _AsyncUploadFile]]]:
5154
"""
5255
Parse an asynchronous stream in multipart format
@@ -60,6 +63,8 @@ async def parse_async_stream(
6063
field_name = ""
6164
data = bytearray()
6265
file: Optional[_AsyncUploadFile] = None
66+
form_parts_count = 0
67+
form_memory_size_count = 0
6368

6469
items: List[Tuple[str, Union[str, _AsyncUploadFile]]] = []
6570

@@ -77,6 +82,14 @@ async def parse_async_stream(
7782
elif isinstance(event, Data):
7883
if file is None:
7984
data.extend(event.data)
85+
86+
# Check if the in-memory (non-file) field data has exceeded the maximum memory size
87+
form_memory_size_count += len(event.data)
88+
if (
89+
max_form_memory_size is not None
90+
and form_memory_size_count > max_form_memory_size
91+
):
92+
raise RequestEntityTooLarge()
8093
else:
8194
await file.awrite(event.data)
8295

@@ -88,6 +101,11 @@ async def parse_async_stream(
88101
await file.aseek(0)
89102
items.append((field_name, file))
90103
file = None
104+
105+
# Check if we have exceeded the maximum number of form parts
106+
form_parts_count += 1
107+
if form_parts_count > max_form_parts:
108+
raise RequestEntityTooLarge()
91109
return items
92110

93111

@@ -97,6 +115,8 @@ def parse_stream(
97115
charset: str,
98116
*,
99117
file_factory: Type[_SyncUploadFile],
118+
max_form_parts: int = 324,
119+
max_form_memory_size: Union[int, None] = None,
100120
) -> List[Tuple[str, Union[str, _SyncUploadFile]]]:
101121
"""
102122
Parse a synchronous stream in multipart format
@@ -110,6 +130,8 @@ def parse_stream(
110130
field_name = ""
111131
data = bytearray()
112132
file: Optional[_SyncUploadFile] = None
133+
form_parts_count = 0
134+
form_memory_size_count = 0
113135

114136
items: List[Tuple[str, Union[str, _SyncUploadFile]]] = []
115137

@@ -127,6 +149,14 @@ def parse_stream(
127149
elif isinstance(event, Data):
128150
if file is None:
129151
data.extend(event.data)
152+
153+
# Check if the in-memory (non-file) field data has exceeded the maximum memory size
154+
form_memory_size_count += len(event.data)
155+
if (
156+
max_form_memory_size is not None
157+
and form_memory_size_count > max_form_memory_size
158+
):
159+
raise RequestEntityTooLarge()
130160
else:
131161
file.write(event.data)
132162

@@ -138,4 +168,9 @@ def parse_stream(
138168
file.seek(0)
139169
items.append((field_name, file))
140170
file = None
171+
172+
# Check if we have exceeded the maximum number of form parts
173+
form_parts_count += 1
174+
if form_parts_count > max_form_parts:
175+
raise RequestEntityTooLarge()
141176
return items

baize/wsgi/requests.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,11 @@ def json(self) -> Any:
168168

169169
raise UnsupportedMediaType("application/json")
170170

171+
def _parse_multipart(self, boundary: bytes, charset: str) -> FormData:
172+
return FormData(
173+
parse_multipart(self.stream(), boundary, charset, file_factory=UploadFile)
174+
)
175+
171176
@cached_property
172177
def form(self) -> FormData:
173178
"""
@@ -185,11 +190,7 @@ def form(self) -> FormData:
185190
if "boundary" not in self.content_type.options:
186191
raise MalformedMultipart("Missing boundary in header content-type")
187192
boundary = self.content_type.options["boundary"].encode("latin-1")
188-
return FormData(
189-
parse_multipart(
190-
self.stream(), boundary, charset, file_factory=UploadFile
191-
)
192-
)
193+
return self._parse_multipart(boundary, charset)
193194
if self.content_type == "application/x-www-form-urlencoded":
194195
body = self.body.decode(
195196
encoding=self.content_type.options.get("charset", "latin-1")

tests/test_asgi.py

Lines changed: 40 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,11 +32,12 @@
3232
request_response,
3333
websocket_session,
3434
)
35-
from baize.datastructures import UploadFile
35+
from baize.datastructures import FormData, UploadFile
3636
from baize.exceptions import (
3737
HTTPException,
3838
MalformedJSON,
3939
MalformedMultipart,
40+
RequestEntityTooLarge,
4041
UnsupportedMediaType,
4142
)
4243
from baize.typing import Message, ServerSentEvent
@@ -221,6 +222,44 @@ async def app(scope, receive, send):
221222
)
222223

223224

225+
@pytest.mark.asyncio
226+
async def test_request_multipart_form_limit():
227+
class FRequest(Request):
228+
async def _parse_multipart(self, boundary: bytes, charset: str) -> FormData:
229+
from baize.multipart_helper import parse_async_stream as parse_multipart
230+
231+
return FormData(
232+
await parse_multipart(
233+
self.stream(),
234+
boundary,
235+
charset,
236+
file_factory=UploadFile,
237+
max_form_parts=2,
238+
max_form_memory_size=1024,
239+
)
240+
)
241+
242+
async def app(scope, receive, send):
243+
await FRequest(scope, receive).form
244+
245+
async with httpx.AsyncClient(app=app, base_url="http://testServer/") as client:
246+
with pytest.raises(RequestEntityTooLarge):
247+
with tempfile.SpooledTemporaryFile(1024) as file:
248+
file.write(b"temporary file")
249+
file.seek(0, 0)
250+
await client.post(
251+
"/", data={"part1": "1", "part2": "2"}, files={"file-key": file}
252+
)
253+
254+
with pytest.raises(RequestEntityTooLarge):
255+
with tempfile.SpooledTemporaryFile(1024) as file:
256+
file.write(b"temporary file")
257+
file.seek(0, 0)
258+
await client.post(
259+
"/", data={"abc": "*" * 2048}, files={"file-key": file}
260+
)
261+
262+
224263
@pytest.mark.asyncio
225264
async def test_request_body_then_stream():
226265
async def app(scope, receive, send):

tests/test_wsgi.py

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,12 @@
77
import httpx
88
import pytest
99

10-
from baize.datastructures import Address, UploadFile
10+
from baize.datastructures import Address, FormData, UploadFile
1111
from baize.exceptions import (
1212
HTTPException,
1313
MalformedJSON,
1414
MalformedMultipart,
15+
RequestEntityTooLarge,
1516
UnsupportedMediaType,
1617
)
1718
from baize.typing import ServerSentEvent
@@ -194,6 +195,41 @@ def app(environ, start_response):
194195
)
195196

196197

198+
def test_request_multipart_form_limit():
199+
class FRequest(Request):
200+
def _parse_multipart(self, boundary: bytes, charset: str) -> FormData:
201+
from baize.multipart_helper import parse_stream as parse_multipart
202+
203+
return FormData(
204+
parse_multipart(
205+
self.stream(),
206+
boundary,
207+
charset,
208+
file_factory=UploadFile,
209+
max_form_parts=2,
210+
max_form_memory_size=1024,
211+
)
212+
)
213+
214+
def app(environ, start_response):
215+
FRequest(environ).form
216+
217+
with httpx.Client(app=app, base_url="http://testServer/") as client:
218+
with pytest.raises(RequestEntityTooLarge):
219+
with tempfile.SpooledTemporaryFile(1024) as file:
220+
file.write(b"temporary file")
221+
file.seek(0, 0)
222+
client.post(
223+
"/", data={"part1": "1", "part2": "2"}, files={"file-key": file}
224+
)
225+
226+
with pytest.raises(RequestEntityTooLarge):
227+
with tempfile.SpooledTemporaryFile(1024) as file:
228+
file.write(b"temporary file")
229+
file.seek(0, 0)
230+
client.post("/", data={"abc": "*" * 2048}, files={"file-key": file})
231+
232+
197233
def test_request_body_then_stream():
198234
def app(environ, start_response):
199235
request = Request(environ)

0 commit comments

Comments
 (0)