Skip to content

Commit 6732164

Browse files
authored
Speed up multipart header parsing and callback dispatch (#295)
1 parent 9d3ead5 commit 6732164

1 file changed

Lines changed: 40 additions & 35 deletions

File tree

python_multipart/multipart.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,12 @@ class MultipartState(IntEnum):
133133
# Mask for ASCII characters that can be http tokens.
134134
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
135135
# and these: !#$%&'*+-.^_`|~
136-
TOKEN_CHARS_SET = frozenset(
136+
TOKEN_CHARS = (
137137
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
138138
b"abcdefghijklmnopqrstuvwxyz"
139139
b"0123456789"
140140
b"!#$%&'*+-.^_`|~")
141+
TOKEN_CHARS_SET = frozenset(TOKEN_CHARS)
141142
# fmt: on
142143

143144
DEFAULT_MAX_HEADER_COUNT = 8
@@ -647,8 +648,7 @@ def callback(
647648
end: An integer that is passed to the data callback.
648649
start: An integer that is passed to the data callback.
649650
"""
650-
on_name = "on_" + name
651-
func = self.callbacks.get(on_name)
651+
func = self.callbacks.get("on_" + name)
652652
if func is None:
653653
return
654654
func = cast("Callable[..., Any]", func)
@@ -657,11 +657,8 @@ def callback(
657657
# Don't do anything if we have start == end.
658658
if start is not None and start == end:
659659
return
660-
661-
self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end)
662660
func(data, start, end)
663661
else:
664-
self.logger.debug("Calling %s with no data", on_name)
665662
func()
666663

667664
def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
@@ -1078,6 +1075,7 @@ def write(self, data: bytes) -> int:
10781075
def _internal_write(self, data: bytes, length: int) -> int:
10791076
# Get values from locals.
10801077
boundary = self.boundary
1078+
boundary_length = len(boundary)
10811079

10821080
# Get our state, flags and index. These are persisted between calls to
10831081
# this function.
@@ -1128,7 +1126,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11281126
# We need to use self.flags (and not flags) because we care about
11291127
# the state when we entered the loop.
11301128
lookbehind_len = -marked_index
1131-
if lookbehind_len <= len(boundary):
1129+
if lookbehind_len <= boundary_length:
11321130
self.callback(name, boundary, 0, lookbehind_len)
11331131
elif self.flags & FLAG_PART_BOUNDARY:
11341132
lookback = boundary + b"\r\n"
@@ -1173,7 +1171,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11731171
elif state == MultipartState.START_BOUNDARY:
11741172
# Check to ensure that the last 2 characters in our boundary
11751173
# are CRLF.
1176-
if index == len(boundary) - 2:
1174+
if index == boundary_length - 2:
11771175
if c == HYPHEN:
11781176
# Potential empty message.
11791177
state = MultipartState.END_BOUNDARY
@@ -1185,7 +1183,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11851183

11861184
index += 1
11871185

1188-
elif index == len(boundary) - 2 + 1:
1186+
elif index == boundary_length - 1:
11891187
if c != LF:
11901188
msg = "Did not find LF at end of boundary (%d)" % (i,)
11911189
self.logger.warning(msg)
@@ -1247,31 +1245,38 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
12471245
i += 1
12481246
continue
12491247

1250-
# Increment our index in the header.
1251-
index += 1
1248+
# The field name runs until the colon; jump straight to it and
1249+
# validate the whole span at once instead of byte by byte.
1250+
colon = data.find(b":", i, length)
1251+
end = colon if colon != -1 else length
1252+
field = data[i:end]
1253+
if field.translate(None, TOKEN_CHARS):
1254+
bad = next(b for b in field if b not in TOKEN_CHARS_SET)
1255+
bad_i = i + field.index(bad)
1256+
msg = "Found invalid character %r in header at %d" % (bad, bad_i)
1257+
self.logger.warning(msg)
1258+
raise MultipartParseError(msg, offset=bad_i)
12521259

1253-
# If we've reached a colon, we're done with this header.
1254-
if c == COLON:
1255-
advance_header_size()
1260+
index += end - i
1261+
if colon == -1:
1262+
# Field name continues into the next chunk.
1263+
advance_header_size(end - i)
1264+
i = length
1265+
else:
1266+
advance_header_size(end - i + 1)
12561267
# A 0-length header is an error.
1257-
if index == 1:
1268+
if index == 0:
12581269
msg = "Found 0-length header at %d" % (i,)
12591270
self.logger.warning(msg)
12601271
raise MultipartParseError(msg, offset=i)
12611272

12621273
# Call our callback with the header field.
1274+
i = colon
12631275
data_callback("header_field", i)
12641276

12651277
# Move to parsing the header value.
12661278
state = MultipartState.HEADER_VALUE_START
12671279

1268-
elif c not in TOKEN_CHARS_SET:
1269-
msg = "Found invalid character %r in header at %d" % (c, i)
1270-
self.logger.warning(msg)
1271-
raise MultipartParseError(msg, offset=i)
1272-
else:
1273-
advance_header_size()
1274-
12751280
elif state == MultipartState.HEADER_VALUE_START:
12761281
# Skip leading spaces.
12771282
if c == SPACE:
@@ -1287,15 +1292,19 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
12871292
i -= 1
12881293

12891294
elif state == MultipartState.HEADER_VALUE:
1290-
# If we've got a CR, we're nearly done our headers. Otherwise,
1291-
# we do nothing and just move past this character.
1292-
if c == CR:
1295+
# The value runs until the terminating CR; jump straight to it
1296+
# instead of inspecting every byte.
1297+
cr = data.find(b"\r", i, length)
1298+
end = cr if cr != -1 else length
1299+
advance_header_size(end - i)
1300+
if cr != -1:
1301+
i = cr
12931302
data_callback("header_value", i)
12941303
self.callback("header_end")
12951304
current_header_size = 0
12961305
state = MultipartState.HEADER_VALUE_ALMOST_DONE
12971306
else:
1298-
advance_header_size()
1307+
i = length
12991308

13001309
elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
13011310
# The last character should be a LF. If not, it's an error.
@@ -1338,17 +1347,13 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
13381347
# find part of a boundary, but it doesn't match fully.
13391348
prev_index = index
13401349

1341-
# Set up variables.
1342-
boundary_length = len(boundary)
1343-
data_length = length
1344-
13451350
# If our index is 0, we're starting a new part, so start our
13461351
# search.
13471352
if index == 0:
13481353
# The most common case is likely to be that the whole
13491354
# boundary is present in the buffer.
13501355
# Calling `find` is much faster than iterating here.
1351-
i0 = data.find(boundary, i, data_length)
1356+
i0 = data.find(boundary, i, length)
13521357
if i0 >= 0:
13531358
# We matched the whole boundary string.
13541359
index = boundary_length - 1
@@ -1360,9 +1365,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
13601365
# Since the length to be searched is limited to the
13611366
# boundary length, scan the tail for boundary[0] via
13621367
# bytes.find (C-level) to keep cost off the Python loop.
1363-
i = max(i, data_length - boundary_length)
1364-
j = data.find(boundary[:1], i, data_length - 1)
1365-
i = j if j >= 0 else data_length - 1
1368+
i = max(i, length - boundary_length)
1369+
j = data.find(boundary[:1], i, length - 1)
1370+
i = j if j >= 0 else length - 1
13661371

13671372
c = data[i]
13681373

@@ -1456,7 +1461,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
14561461
i -= 1
14571462

14581463
elif state == MultipartState.END_BOUNDARY:
1459-
if index == len(boundary) - 2 + 1:
1464+
if index == boundary_length - 1:
14601465
if c != HYPHEN:
14611466
msg = "Did not find - at end of boundary (%d)" % (i,)
14621467
self.logger.warning(msg)

0 commit comments

Comments
 (0)