File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -257,6 +257,13 @@ def __init__(
257257 f"must match block_sizes length ({ len (block_sizes )} )"
258258 )
259259
260+ # Align to a multiple of (128 / block_size) as required
261+ # by some attention backends such as TRTLLM (#39324)
262+ max_num_blocks = [
263+ cdiv (n , 128 // bs ) * (128 // bs ) if bs <= 128 else n
264+ for n , bs in zip (max_num_blocks , block_sizes )
265+ ]
266+
260267 self .block_tables = [
261268 BlockTable (
262269 block_size ,
Original file line number Diff line number Diff line change @@ -41,6 +41,11 @@ def __init__(
4141 # As a result, one block on the current rank covers `block_size * cp_size`
4242 # tokens in the full, global (unsharded) sequence.
4343 max_num_blocks = cdiv (self .max_model_len , block_size * self .cp_size )
44+ # Align to a multiple of (128 / block_size) as required
45+ # by some attention backends such as TRTLLM (#39324)
46+ if block_size <= 128 :
47+ alignment = 128 // block_size
48+ max_num_blocks = cdiv (max_num_blocks , alignment ) * alignment
4449 block_table = StagedWriteTensor (
4550 (self .max_num_reqs , max_num_blocks ),
4651 dtype = torch .int32 ,
You can’t perform that action at this time.
0 commit comments