Skip to content

Commit e09f4f1

Browse files
committed
Refactor multipart buffered reading
1 parent b392f5a commit e09f4f1

3 files changed

Lines changed: 162 additions & 21 deletions

File tree

.github/workflows/release.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,8 @@ jobs:
137137
shell: bash
138138
- name: Generate PGO data
139139
shell: bash
140+
env:
141+
PGO_RUN: y
140142
run: |
141143
uv python install ${{ env.UV_PYTHON }}
142144
uv venv .venv

src/multipart/parse.rs

Lines changed: 84 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ use pyo3::{IntoPyObjectExt, exceptions::PyStopIteration, prelude::*, types::PyBy
99
use std::{
1010
borrow::Cow,
1111
collections::VecDeque,
12-
io::{BufRead, Cursor, Read},
12+
io::{BufRead, Cursor, Read, Write},
1313
mem,
1414
sync::Mutex,
1515
};
@@ -27,6 +27,7 @@ enum MultiPartParserState {
2727
Value(Part),
2828
File(FilePart),
2929
Skip,
30+
Consumed,
3031
}
3132

3233
impl Default for MultiPartParserState {
@@ -41,8 +42,8 @@ struct MultiPartParser {
4142
max_part_size: usize,
4243
state: MultiPartParserState,
4344
buffer: Vec<u8>,
45+
bufshift: usize,
4446
read_size: usize,
45-
pub consumed: bool,
4647
stack: VecDeque<Node>,
4748
}
4849

@@ -54,8 +55,8 @@ impl MultiPartParser {
5455
max_part_size,
5556
state: MultiPartParserState::Clean,
5657
buffer: Vec::new(),
58+
bufshift: 0,
5759
read_size: 0,
58-
consumed: false,
5960
stack: VecDeque::new(),
6061
}
6162
}
@@ -65,26 +66,79 @@ impl MultiPartParser {
6566
T: AsRef<[u8]>,
6667
{
6768
macro_rules! buffered_read {
69+
($boundary:expr) => {{
70+
let peeker = reader.fill_buf()?;
71+
if peeker.is_empty() {
72+
return Ok(());
73+
}
74+
75+
// if the chunk is not long enough to check for boundary, buffer
76+
if (peeker.len() + self.buffer.len()) < $boundary.len() {
77+
reader.read_to_end(&mut self.buffer)?;
78+
return Ok(());
79+
}
80+
81+
let (readn, found) = if self.buffer.is_empty() {
82+
reader.stream_until_token($boundary, &mut self.buffer)?
83+
} else {
84+
// we buffered previous contents, chain the two reads
85+
let mut buf = Vec::new();
86+
let mut chain = self.buffer.chain(&mut *reader);
87+
let ret = chain.stream_until_token($boundary, &mut buf)?;
88+
self.buffer.truncate(0);
89+
self.buffer.extend(buf);
90+
ret
91+
};
92+
if !found {
93+
let bdiff = self.buffer.len() + self.bufshift;
94+
if bdiff < readn {
95+
let shift = readn - bdiff;
96+
self.buffer.extend(&$boundary[..self.bufshift + shift]);
97+
self.bufshift += shift;
98+
} else {
99+
self.bufshift = 0;
100+
}
101+
} else {
102+
self.bufshift = 0;
103+
}
104+
(readn, found)
105+
}};
106+
68107
($boundary:expr, $target:expr) => {{
69108
let peeker = reader.fill_buf()?;
70109
if peeker.is_empty() {
71110
return Ok(());
72111
}
112+
73113
// if the chunk is not long enough to check for boundary, buffer
74114
if (peeker.len() + self.buffer.len()) < $boundary.len() {
75115
reader.read_to_end(&mut self.buffer)?;
76116
return Ok(());
77117
}
78118

79-
if self.buffer.is_empty() {
119+
let (readn, found) = if self.buffer.is_empty() {
80120
reader.stream_until_token($boundary, $target)?
81121
} else {
82122
// we buffered previous contents, chain the two reads
83123
let mut chain = self.buffer.chain(&mut *reader);
84124
let ret = chain.stream_until_token($boundary, $target)?;
85125
self.buffer.truncate(0);
86126
ret
127+
};
128+
if !found {
129+
// keep incomplete boundary segment in buffer
130+
let bdiff = $target.len() + self.bufshift;
131+
if bdiff < readn {
132+
let shift = readn - bdiff;
133+
self.buffer.extend(&$boundary[..self.bufshift + shift]);
134+
self.bufshift += shift;
135+
} else {
136+
self.bufshift = 0;
137+
}
138+
} else {
139+
self.bufshift = 0;
87140
}
141+
(readn, found)
88142
}};
89143
}
90144

@@ -93,24 +147,26 @@ impl MultiPartParser {
93147
loop {
94148
if let MultiPartParserState::Clean = self.state {
95149
let peeker = reader.fill_buf()?;
96-
97-
// If the last chunk is empty and we're in clean state there's nothing to do.
98-
if peeker.is_empty() {
150+
if (self.buffer.len() + peeker.len()) < 2 {
151+
self.buffer.extend(peeker);
99152
return Ok(());
100153
}
101154

102155
// If the next two lookahead characters are '--', parsing is finished.
103-
if peeker.len() >= 2 && &peeker[..2] == b"--" {
104-
self.consumed = true;
156+
let mut buf = vec![0; 2];
157+
let mut chain = self.buffer.chain(peeker);
158+
chain.read_exact(&mut buf)?;
159+
if buf.len() >= 2 && &buf[..2] == b"--" {
160+
self.state = MultiPartParserState::Consumed;
105161
return Ok(());
106162
}
107163

108164
self.state = MultiPartParserState::Termination;
109165
}
110166

111167
if let MultiPartParserState::Termination = self.state {
112-
// Read the line terminator after the boundary
113-
let (_, found) = reader.stream_until_token(lt, &mut self.buffer)?;
168+
let (_, found) = buffered_read!(lt);
169+
114170
if !found {
115171
return Ok(());
116172
}
@@ -120,8 +176,7 @@ impl MultiPartParser {
120176
}
121177

122178
if let MultiPartParserState::Headers = self.state {
123-
// Read the headers (which end in 2 line terminators)
124-
let (_, found) = reader.stream_until_token(ltlt, &mut self.buffer)?;
179+
let (_, found) = buffered_read!(ltlt);
125180
if !found {
126181
return Ok(());
127182
}
@@ -202,12 +257,14 @@ impl MultiPartParser {
202257
}
203258

204259
if let MultiPartParserState::File(filepart) = &mut self.state {
205-
let (read, found) = buffered_read!(
206-
lt_boundary,
207-
&mut filepart.file.as_mut().expect("uninitialized file part")
208-
);
209-
let size = filepart.size.unwrap_or(0);
210-
filepart.size = Some(size + read);
260+
let mut buf = Vec::new();
261+
let (read, found) = buffered_read!(lt_boundary, &mut buf);
262+
filepart
263+
.file
264+
.as_mut()
265+
.expect("uninitialized file part")
266+
.write_all(&buf)?;
267+
filepart.size = Some(filepart.size.unwrap_or(0) + read);
211268

212269
if !found {
213270
return Ok(());
@@ -225,7 +282,7 @@ impl MultiPartParser {
225282
}
226283

227284
if let MultiPartParserState::Skip = &mut self.state {
228-
let (_, found) = reader.stream_until_token(lt_boundary, &mut self.buffer)?;
285+
let (_, found) = buffered_read!(lt_boundary);
229286
if !found {
230287
return Ok(());
231288
}
@@ -262,6 +319,9 @@ impl MultiPartReader {
262319
let mut guard = self.inner.lock().unwrap();
263320

264321
if let Some(inner) = &mut *guard {
322+
if matches!(inner.state, MultiPartParserState::Consumed) {
323+
return Ok(());
324+
}
265325
let mut reader = Cursor::new(data);
266326
return inner.parse_chunk(&mut reader);
267327
}
@@ -297,7 +357,10 @@ impl MultiPartReader {
297357
let mut guard = self.inner.lock().unwrap();
298358

299359
if let Some(mut inner) = guard.take() {
300-
if !inner.consumed {
360+
if !matches!(
361+
inner.state,
362+
MultiPartParserState::Clean | MultiPartParserState::Consumed
363+
) {
301364
return Err(error_state!());
302365
}
303366
let nodes = mem::take(&mut inner.stack);
Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,76 @@
1+
import os
2+
3+
import pytest
4+
5+
from emmett_core._emmett_core import MultiPartReader
6+
7+
8+
@pytest.mark.skipif(bool(os.getenv("PGO_RUN")), reason="PGO build")
9+
def test_multipart_mixed_segmented():
10+
data = (
11+
# data
12+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
13+
b'Content-Disposition: form-data; name="field0"\r\n\r\n'
14+
b"value0\r\n"
15+
# file
16+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
17+
b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n'
18+
b"Content-Type: text/plain\r\n\r\n"
19+
b"<file content>\r\n"
20+
# data
21+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
22+
b'Content-Disposition: form-data; name="field1"\r\n\r\n'
23+
b"value1\r\n"
24+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
25+
)
26+
27+
parser = MultiPartReader("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")
28+
parser.parse(data[:37])
29+
30+
idx = 37
31+
while True:
32+
segment = data[idx : idx + 1]
33+
if not segment:
34+
break
35+
parser.parse(segment)
36+
idx += 1
37+
parsed = list(parser.contents())
38+
assert (parsed[0][0], parsed[0][2]) == ("field0", b"value0")
39+
assert (parsed[2][0], parsed[2][2]) == ("field1", b"value1")
40+
assert (parsed[1][0], parsed[1][2].filename, parsed[1][2].read()) == ("file", "file.txt", b"<file content>")
41+
42+
43+
def test_multipart_mixed_chunked():
44+
data = (
45+
# data
46+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
47+
b'Content-Disposition: form-data; name="field0"\r\n\r\n'
48+
b"value0\r\n"
49+
# file
50+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
51+
b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n'
52+
b"Content-Type: text/plain\r\n\r\n"
53+
b"<file content>\r\n"
54+
# data
55+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
56+
b'Content-Disposition: form-data; name="field1"\r\n\r\n'
57+
b"value1\r\n"
58+
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
59+
)
60+
61+
step = 97
62+
63+
parser = MultiPartReader("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")
64+
parser.parse(data[:step])
65+
66+
idx = 1
67+
while True:
68+
segment = data[idx * step : (idx + 1) * step]
69+
if not segment:
70+
break
71+
parser.parse(segment)
72+
idx += 1
73+
parsed = list(parser.contents())
74+
assert (parsed[0][0], parsed[0][2]) == ("field0", b"value0")
75+
assert (parsed[2][0], parsed[2][2]) == ("field1", b"value1")
76+
assert (parsed[1][0], parsed[1][2].filename, parsed[1][2].read()) == ("file", "file.txt", b"<file content>")

0 commit comments

Comments
 (0)