Skip to content

Commit 38e1667

Browse files
authored
[Bugfix] Align block table for TRTLLM MLA edge-case (vllm-project#39324)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent 27702f6 commit 38e1667

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

vllm/v1/worker/block_table.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,13 @@ def __init__(
257257
f"must match block_sizes length ({len(block_sizes)})"
258258
)
259259

260+
# Align to a multiple of (128 / block_size) as required
261+
# by some attention backends such as TRTLLM (#39324)
262+
max_num_blocks = [
263+
cdiv(n, 128 // bs) * (128 // bs) if bs <= 128 else n
264+
for n, bs in zip(max_num_blocks, block_sizes)
265+
]
266+
260267
self.block_tables = [
261268
BlockTable(
262269
block_size,

vllm/v1/worker/gpu/block_table.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def __init__(
4141
# As a result, one block on the current rank covers `block_size * cp_size`
4242
# tokens in the full, global (unsharded) sequence.
4343
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
44+
# Align to a multiple of (128 / block_size) as required
45+
# by some attention backends such as TRTLLM (#39324)
46+
if block_size <= 128:
47+
alignment = 128 // block_size
48+
max_num_blocks = cdiv(max_num_blocks, alignment) * alignment
4449
block_table = StagedWriteTensor(
4550
(self.max_num_reqs, max_num_blocks),
4651
dtype=torch.int32,

0 commit comments

Comments (0)