Skip to content

Commit 953b66b

Browse files
committed
Minor enhancements to multipart file handling
1 parent f616eab commit 953b66b

4 files changed

Lines changed: 78 additions & 10 deletions

File tree

emmett_core/http/wrappers/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class FileStorage:
5656
def __init__(self, file):
5757
self.file = file
5858

59+
def __iter__(self):
60+
return self.file.__iter__()
61+
5962
def __getattr__(self, name):
6063
return getattr(self.file, name)
6164

src/multipart/parse.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ impl MultiPartParser {
175175
}
176176

177177
if let MultiPartParserState::File(filepart) = &mut self.state {
178+
// TODO: potentially release the GIL here (allow Python threads) while streaming to the file?
178179
let (read, found) = reader.stream_until_token(
179180
lt_boundary,
180181
&mut filepart.file.as_mut().expect("uninitialized file part"),
@@ -189,6 +190,8 @@ impl MultiPartParser {
189190
let state = mem::take(&mut self.state);
190191
match state {
191192
MultiPartParserState::File(part) => {
193+
// TODO: potentially release the GIL here (allow Python threads) while syncing file data?
194+
part.file.as_ref().unwrap().sync_data()?;
192195
self.stack.push_back(Node::File(part));
193196
}
194197
_ => unreachable!(),

src/multipart/parts.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use http::header::{self, HeaderMap};
33
use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyBytes};
44
use std::{
55
fs::File,
6-
io::{BufRead, BufReader, Read},
6+
io::{BufReader, Read},
77
path::PathBuf,
88
sync::Mutex,
99
};
@@ -117,7 +117,7 @@ impl FilePartReader {
117117
drop(inner.file.take().expect("uninitialized file part"));
118118
let file = File::open(inner.path.clone()).map_err::<anyhow::Error, _>(|_| error_io!())?;
119119
let size = file.metadata().unwrap().len();
120-
let reader = Mutex::new(BufReader::with_capacity(4096, file));
120+
let reader = Mutex::new(BufReader::with_capacity(131_072, file));
121121
Ok(Self { inner, reader, size })
122122
}
123123

@@ -128,7 +128,6 @@ impl FilePartReader {
128128
let mut len_read = 0;
129129

130130
while len_read < size {
131-
guard.fill_buf()?;
132131
let rsize = guard.read(&mut buf[len_read..])?;
133132
if rsize == 0 {
134133
break;
@@ -143,8 +142,6 @@ impl FilePartReader {
143142

144143
fn read_all(&self) -> Result<Vec<u8>> {
145144
let mut guard = self.reader.lock().unwrap();
146-
147-
guard.fill_buf()?;
148145
let mut buf = Vec::new();
149146
guard.read_to_end(&mut buf)?;
150147
Ok(buf)
@@ -174,8 +171,8 @@ impl FilePartReader {
174171
#[pyo3(signature = (size = None))]
175172
fn read<'p>(&self, py: Python<'p>, size: Option<usize>) -> Result<Bound<'p, PyBytes>> {
176173
let buf = match size {
177-
Some(size) => self.read_chunk(size),
178-
None => self.read_all(),
174+
Some(size) => py.allow_threads(|| self.read_chunk(size)),
175+
None => py.allow_threads(|| self.read_all()),
179176
}?;
180177
Ok(PyBytes::new(py, &buf[..]))
181178
}
@@ -185,7 +182,7 @@ impl FilePartReader {
185182
}
186183

187184
fn __next__<'p>(&self, py: Python<'p>) -> Result<Bound<'p, PyBytes>> {
188-
let buf = self.read_chunk(4096)?;
185+
let buf = py.allow_threads(|| self.read_chunk(131_072))?;
189186
if buf.is_empty() {
190187
return Err(PyStopIteration::new_err(py.None()).into());
191188
}

tests/multipart/test_multipart.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,37 @@ async def multipart():
5757
return client
5858

5959

60+
@pytest.fixture(scope="function")
61+
def multipart_stream_client(current, client, tmpdir: Path):
62+
app = client.application
63+
64+
@app.route("/", output="str")
65+
async def multipart():
66+
target = tmpdir / "save.txt"
67+
files = await current.request.files
68+
with target.open("wb") as tf:
69+
for chunk in files.test:
70+
tf.write(chunk)
71+
return ""
72+
73+
return client
74+
75+
76+
@pytest.fixture(scope="function")
77+
def multipart_copy_client(current, client, tmpdir: Path):
78+
app = client.application
79+
80+
@app.route("/", output="str")
81+
async def multipart():
82+
target = tmpdir / "save.txt"
83+
files = await current.request.files
84+
with target.open("wb") as tf:
85+
tf.write(files.test.read())
86+
return ""
87+
88+
return client
89+
90+
6091
def test_multipart_request_data(multipart_client):
6192
response = multipart_client.post("/", data={"some": "data"}, content_type="multipart/form-data")
6293
assert response.json() == {"params": {"some": ["data"]}, "files": {}}
@@ -340,8 +371,8 @@ def test_multipart_request_file_save(tmpdir: Path, multipart_save_client):
340371
target = tmpdir / "save.txt"
341372
with path.open("wb") as file:
342373
file.write(b"<")
343-
for i in range(8192):
344-
file.write(f"{i}".zfill(5).encode("utf8"))
374+
for i in range(8192 * 128):
375+
file.write(f"{i}".zfill(7).encode("utf8"))
345376
file.write(b">")
346377

347378
with path.open("rb") as f:
@@ -350,3 +381,37 @@ def test_multipart_request_file_save(tmpdir: Path, multipart_save_client):
350381

351382
with path.open("rb") as f1, target.open("rb") as f2:
352383
assert f1.read() == f2.read()
384+
385+
386+
def test_multipart_request_file_stream(tmpdir: Path, multipart_stream_client):
387+
path = tmpdir / "test.txt"
388+
target = tmpdir / "save.txt"
389+
with path.open("wb") as file:
390+
file.write(b"<")
391+
for i in range(8192 * 128):
392+
file.write(f"{i}".zfill(7).encode("utf8"))
393+
file.write(b">")
394+
395+
with path.open("rb") as f:
396+
response = multipart_stream_client.post("/", data={"test": (f, "test.txt", "text/plain")})
397+
assert response.status == 200
398+
399+
with path.open("rb") as f1, target.open("rb") as f2:
400+
assert f1.read() == f2.read()
401+
402+
403+
def test_multipart_request_file_copy(tmpdir: Path, multipart_copy_client):
404+
path = tmpdir / "test.txt"
405+
target = tmpdir / "save.txt"
406+
with path.open("wb") as file:
407+
file.write(b"<")
408+
for i in range(8192 * 128):
409+
file.write(f"{i}".zfill(7).encode("utf8"))
410+
file.write(b">")
411+
412+
with path.open("rb") as f:
413+
response = multipart_copy_client.post("/", data={"test": (f, "test.txt", "text/plain")})
414+
assert response.status == 200
415+
416+
with path.open("rb") as f1, target.open("rb") as f2:
417+
assert f1.read() == f2.read()

0 commit comments

Comments
 (0)