@@ -133,11 +133,12 @@ class MultipartState(IntEnum):
133133# Mask for ASCII characters that can be http tokens.
134134# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
135135# and these: !#$%&'*+-.^_`|~
136- TOKEN_CHARS_SET = frozenset (
136+ TOKEN_CHARS = (
137137 b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
138138 b"abcdefghijklmnopqrstuvwxyz"
139139 b"0123456789"
140140 b"!#$%&'*+-.^_`|~" )
141+ TOKEN_CHARS_SET = frozenset (TOKEN_CHARS )
141142# fmt: on
142143
143144DEFAULT_MAX_HEADER_COUNT = 8
@@ -647,8 +648,7 @@ def callback(
647648 end: An integer that is passed to the data callback.
648649 start: An integer that is passed to the data callback.
649650 """
650- on_name = "on_" + name
651- func = self .callbacks .get (on_name )
651+ func = self .callbacks .get ("on_" + name )
652652 if func is None :
653653 return
654654 func = cast ("Callable[..., Any]" , func )
@@ -657,11 +657,8 @@ def callback(
657657 # Don't do anything if we have start == end.
658658 if start is not None and start == end :
659659 return
660-
661- self .logger .debug ("Calling %s with data[%d:%d]" , on_name , start , end )
662660 func (data , start , end )
663661 else :
664- self .logger .debug ("Calling %s with no data" , on_name )
665662 func ()
666663
667664 def set_callback (self , name : CallbackName , new_func : Callable [..., Any ] | None ) -> None :
@@ -1078,6 +1075,7 @@ def write(self, data: bytes) -> int:
10781075 def _internal_write (self , data : bytes , length : int ) -> int :
10791076 # Get values from locals.
10801077 boundary = self .boundary
1078+ boundary_length = len (boundary )
10811079
10821080 # Get our state, flags and index. These are persisted between calls to
10831081 # this function.
@@ -1128,7 +1126,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11281126 # We need to use self.flags (and not flags) because we care about
11291127 # the state when we entered the loop.
11301128 lookbehind_len = - marked_index
1131- if lookbehind_len <= len ( boundary ) :
1129+ if lookbehind_len <= boundary_length :
11321130 self .callback (name , boundary , 0 , lookbehind_len )
11331131 elif self .flags & FLAG_PART_BOUNDARY :
11341132 lookback = boundary + b"\r \n "
@@ -1173,7 +1171,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11731171 elif state == MultipartState .START_BOUNDARY :
11741172 # Check to ensure that the last 2 characters in our boundary
11751173 # are CRLF.
1176- if index == len ( boundary ) - 2 :
1174+ if index == boundary_length - 2 :
11771175 if c == HYPHEN :
11781176 # Potential empty message.
11791177 state = MultipartState .END_BOUNDARY
@@ -1185,7 +1183,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
11851183
11861184 index += 1
11871185
1188- elif index == len ( boundary ) - 2 + 1 :
1186+ elif index == boundary_length - 1 :
11891187 if c != LF :
11901188 msg = "Did not find LF at end of boundary (%d)" % (i ,)
11911189 self .logger .warning (msg )
@@ -1247,31 +1245,38 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
12471245 i += 1
12481246 continue
12491247
1250- # Increment our index in the header.
1251- index += 1
1248+ # The field name runs until the colon; jump straight to it and
1249+ # validate the whole span at once instead of byte by byte.
1250+ colon = data .find (b":" , i , length )
1251+ end = colon if colon != - 1 else length
1252+ field = data [i :end ]
1253+ if field .translate (None , TOKEN_CHARS ):
1254+ bad = next (b for b in field if b not in TOKEN_CHARS_SET )
1255+ bad_i = i + field .index (bad )
1256+ msg = "Found invalid character %r in header at %d" % (bad , bad_i )
1257+ self .logger .warning (msg )
1258+ raise MultipartParseError (msg , offset = bad_i )
12521259
1253- # If we've reached a colon, we're done with this header.
1254- if c == COLON :
1255- advance_header_size ()
1260+ index += end - i
1261+ if colon == - 1 :
1262+ # Field name continues into the next chunk.
1263+ advance_header_size (end - i )
1264+ i = length
1265+ else :
1266+ advance_header_size (end - i + 1 )
12561267 # A 0-length header is an error.
1257- if index == 1 :
1268+ if index == 0 :
12581269 msg = "Found 0-length header at %d" % (i ,)
12591270 self .logger .warning (msg )
12601271 raise MultipartParseError (msg , offset = i )
12611272
12621273 # Call our callback with the header field.
1274+ i = colon
12631275 data_callback ("header_field" , i )
12641276
12651277 # Move to parsing the header value.
12661278 state = MultipartState .HEADER_VALUE_START
12671279
1268- elif c not in TOKEN_CHARS_SET :
1269- msg = "Found invalid character %r in header at %d" % (c , i )
1270- self .logger .warning (msg )
1271- raise MultipartParseError (msg , offset = i )
1272- else :
1273- advance_header_size ()
1274-
12751280 elif state == MultipartState .HEADER_VALUE_START :
12761281 # Skip leading spaces.
12771282 if c == SPACE :
@@ -1287,15 +1292,19 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
12871292 i -= 1
12881293
12891294 elif state == MultipartState .HEADER_VALUE :
1290- # If we've got a CR, we're nearly done our headers. Otherwise,
1291- # we do nothing and just move past this character.
1292- if c == CR :
1295+ # The value runs until the terminating CR; jump straight to it
1296+ # instead of inspecting every byte.
1297+ cr = data .find (b"\r " , i , length )
1298+ end = cr if cr != - 1 else length
1299+ advance_header_size (end - i )
1300+ if cr != - 1 :
1301+ i = cr
12931302 data_callback ("header_value" , i )
12941303 self .callback ("header_end" )
12951304 current_header_size = 0
12961305 state = MultipartState .HEADER_VALUE_ALMOST_DONE
12971306 else :
1298- advance_header_size ()
1307+ i = length
12991308
13001309 elif state == MultipartState .HEADER_VALUE_ALMOST_DONE :
13011310 # The last character should be a LF. If not, it's an error.
@@ -1338,17 +1347,13 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
13381347 # find part of a boundary, but it doesn't match fully.
13391348 prev_index = index
13401349
1341- # Set up variables.
1342- boundary_length = len (boundary )
1343- data_length = length
1344-
13451350 # If our index is 0, we're starting a new part, so start our
13461351 # search.
13471352 if index == 0 :
13481353 # The most common case is likely to be that the whole
13491354 # boundary is present in the buffer.
13501355 # Calling `find` is much faster than iterating here.
1351- i0 = data .find (boundary , i , data_length )
1356+ i0 = data .find (boundary , i , length )
13521357 if i0 >= 0 :
13531358 # We matched the whole boundary string.
13541359 index = boundary_length - 1
@@ -1360,9 +1365,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
13601365 # Since the length to be searched is limited to the
13611366 # boundary length, scan the tail for boundary[0] via
13621367 # bytes.find (C-level) to keep cost off the Python loop.
1363- i = max (i , data_length - boundary_length )
1364- j = data .find (boundary [:1 ], i , data_length - 1 )
1365- i = j if j >= 0 else data_length - 1
1368+ i = max (i , length - boundary_length )
1369+ j = data .find (boundary [:1 ], i , length - 1 )
1370+ i = j if j >= 0 else length - 1
13661371
13671372 c = data [i ]
13681373
@@ -1456,7 +1461,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
14561461 i -= 1
14571462
14581463 elif state == MultipartState .END_BOUNDARY :
1459- if index == len ( boundary ) - 2 + 1 :
1464+ if index == boundary_length - 1 :
14601465 if c != HYPHEN :
14611466 msg = "Did not find - at end of boundary (%d)" % (i ,)
14621467 self .logger .warning (msg )
0 commit comments