
Commit 27702f6

starkwj authored (with Copilot, yewentao256, and mergify[bot])
[Bugfix] Fix token loss in PP mode which causes degraded accuracy (vllm-project#41133)
Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 22a3cbe commit 27702f6

3 files changed

Lines changed: 179 additions & 7 deletions

File tree

tests/v1/worker/test_gpu_model_runner.py
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_model_runner.py

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 157 additions & 0 deletions
@@ -163,6 +163,34 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
     )
 
 
+def _schedule_cached_requests(
+    req_ids: list[str],
+    num_scheduled_tokens: dict[str, int],
+    new_token_ids: list[list[int]],
+    num_computed_tokens: list[int],
+    num_output_tokens: list[int],
+) -> SchedulerOutput:
+    return SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData(
+            req_ids=req_ids,
+            resumed_req_ids=set(),
+            new_token_ids=new_token_ids,
+            all_token_ids={},
+            new_block_ids=[None] * len(req_ids),
+            num_computed_tokens=num_computed_tokens,
+            num_output_tokens=num_output_tokens,
+        ),
+        num_scheduled_tokens=num_scheduled_tokens,
+        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=[],
+        finished_req_ids=set(),
+        free_encoder_mm_hashes=[],
+    )
+
+
 def _is_req_scheduled(model_runner, req_id: str) -> bool:
     return req_id in model_runner.input_batch.req_id_to_index

@@ -510,6 +538,135 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
     assert not _is_req_scheduled(model_runner, req_ids[1])
 
 
+def test_update_states_pp_non_async_multi_request_keeps_token_buffers_consistent(
+    model_runner, model_runner_2, dist_init, monkeypatch
+):
+    req_ids = ["req_0", "req_1"]
+    non_last_runner = model_runner
+    last_runner = model_runner_2
+    non_last_runner.use_async_scheduling = False
+    last_runner.use_async_scheduling = False
+
+    # Both ranks start from the same request set.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(_schedule_new_request(*req_ids))
+    last_runner._update_states(_schedule_new_request(*req_ids))
+
+    sampled_by_last_rank = {req_ids[0]: 101, req_ids[1]: 201}
+    # Emulate last-rank bookkeeping result from previous step:
+    # sampled tokens already cached in CPU token buffers.
+    for req_id, token_id in sampled_by_last_rank.items():
+        req_index = last_runner.input_batch.req_id_to_index[req_id]
+        start_idx = int(last_runner.input_batch.num_tokens_no_spec[req_index])
+        end_idx = start_idx + 1
+        last_runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
+        last_runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
+        last_runner.input_batch.num_tokens_no_spec[req_index] = end_idx
+        last_runner.requests[req_id].output_token_ids.append(token_id)
+
+    scheduler_output = _schedule_cached_requests(
+        req_ids=req_ids,
+        num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
+        new_token_ids=[[101], [201]],
+        num_computed_tokens=[3, 3],  # prompt tokens only
+        num_output_tokens=[1, 1],
+    )
+    # non-last rank appends new_token_ids in _update_states.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(scheduler_output)
+    # last rank should keep its already-bookkept CPU buffers unchanged.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=True, world_size=2),
+    )
+    last_runner._update_states(scheduler_output)
+
+    # Verify consistency between PP ranks after _update_states.
+    for req_id in req_ids:
+        non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
+        last_idx = last_runner.input_batch.req_id_to_index[req_id]
+        non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
+        last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
+        assert non_last_len == last_len
+        assert (
+            non_last_runner.input_batch.token_ids_cpu[
+                non_last_idx, :non_last_len
+            ].tolist()
+            == last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
+        )
+
+
+def test_update_states_pp_async_multi_request_keeps_rank_state_consistent(
+    model_runner, model_runner_2, dist_init, monkeypatch
+):
+    req_ids = ["req_0", "req_1"]
+    non_last_runner = model_runner
+    last_runner = model_runner_2
+    non_last_runner.use_async_scheduling = True
+    last_runner.use_async_scheduling = True
+
+    # Both ranks start from the same request set.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(_schedule_new_request(*req_ids))
+    last_runner._update_states(_schedule_new_request(*req_ids))
+
+    # Simulate async previous-step sampled tokens known on both ranks.
+    # non-last rank may receive them via PP communication; last rank has
+    # them from local sampling/bookkeeping.
+    sampled_by_last_rank = {req_ids[0]: 111, req_ids[1]: 222}
+    for runner in (non_last_runner, last_runner):
+        for req_id, token_id in sampled_by_last_rank.items():
+            req_index = runner.input_batch.req_id_to_index[req_id]
+            start_idx = int(runner.input_batch.num_tokens_no_spec[req_index])
+            end_idx = start_idx + 1
+            runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
+            runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
+            runner.input_batch.num_tokens_no_spec[req_index] = end_idx
+            runner.requests[req_id].output_token_ids.append(token_id)
+
+    scheduler_output = _schedule_cached_requests(
+        req_ids=req_ids,
+        num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
+        new_token_ids=[],
+        num_computed_tokens=[4, 4],
+        num_output_tokens=[1, 1],
+    )
+    # non-last rank: async PP branch (new_token_ids empty).
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(scheduler_output)
+    # last rank: keep already-bookkept state aligned with scheduler view.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=True, world_size=2),
+    )
+    last_runner._update_states(scheduler_output)
+
+    for req_id in req_ids:
+        non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
+        last_idx = last_runner.input_batch.req_id_to_index[req_id]
+        non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
+        last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
+        assert non_last_len == last_len
+        assert (
+            non_last_runner.input_batch.token_ids_cpu[
+                non_last_idx, :non_last_len
+            ].tolist()
+            == last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
+        )
+
+
 def test_kv_cache_stride_order(monkeypatch, model_runner):
     # This test checks if GPUModelRunner initializes correctly when an attention
     # backend enforces a non-default KV cache stride order.
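
Note: the two new tests drive a pair of model runners as adjacent PP ranks by monkeypatching get_pp_group. One way to run just these tests locally (illustrative only, not part of this commit; it assumes pytest and the module's fixtures such as model_runner_2 and dist_init are available in your environment):

    import pytest

    # Select only the two new PP consistency tests by substring match.
    pytest.main([
        "tests/v1/worker/test_gpu_model_runner.py",
        "-k", "pp_non_async or pp_async",
        "-q",
    ])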

vllm/v1/worker/gpu_input_batch.py

Lines changed: 1 addition & 0 deletions
@@ -504,6 +504,7 @@ def update_req_spec_token_ids(
         start_index = self.num_tokens_no_spec[req_index]
         end_token_index = start_index + num_spec_tokens
         self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
+        self.is_token_ids[req_index, start_index:end_token_index] = True
         cur_spec_token_ids.extend(spec_token_ids)
 
     def remove_request(self, req_id: str) -> int | None:
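
Note: this one-line addition keeps the is_token_ids validity mask in step with token_ids_cpu when speculative draft tokens are written into the batch. A minimal sketch of the invariant (illustrative only: the numpy buffers and the helper write_spec_tokens below are stand-ins, not the real input-batch API):

    import numpy as np

    max_num_reqs, max_model_len = 2, 16
    token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int64)
    is_token_ids = np.zeros((max_num_reqs, max_model_len), dtype=bool)
    num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int64)

    def write_spec_tokens(req_index: int, spec_token_ids: list[int]) -> None:
        # Mirrors the write in update_req_spec_token_ids: store the draft
        # tokens and, with this patch, also mark those slots as valid.
        start = int(num_tokens_no_spec[req_index])
        end = start + len(spec_token_ids)
        token_ids_cpu[req_index, start:end] = spec_token_ids
        is_token_ids[req_index, start:end] = True  # the added line

    write_spec_tokens(0, [7, 8, 9])
    # Without the mask update, these slots would hold token ids that any
    # consumer of is_token_ids would treat as absent.
    assert is_token_ids[0, :3].all()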

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 7 deletions
@@ -1347,13 +1347,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None
             # For the last rank, we don't need to update the token_ids_cpu
             # because the sampled tokens are already cached.
             if not is_last_rank:
-                # Add new_token_ids to token_ids_cpu.
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index, start_token_index:end_token_index
-                ] = new_token_ids
-                self.input_batch.num_tokens_no_spec[req_index] = end_token_index
+                start_token_index = self.input_batch.num_tokens_no_spec[req_index]
+                # For chunked prefill, num_computed_tokens may be less
+                # than num_tokens_no_spec.
+                # Async scheduled PP: no new_token_ids, advance num_tokens_no_spec
+                # according to num_computed_tokens.
+                end_token_index = max(
+                    start_token_index,
+                    num_computed_tokens + len(new_token_ids),
+                )
+                if end_token_index > start_token_index:
+                    if new_token_ids:
+                        # Add new_token_ids to token_ids_cpu.
+                        num_new_tokens = end_token_index - start_token_index
+                        tokens_to_append = new_token_ids[-num_new_tokens:]
+                        self.input_batch.token_ids_cpu[
+                            req_index, start_token_index:end_token_index
+                        ] = tokens_to_append
+                        self.input_batch.is_token_ids[
+                            req_index, start_token_index:end_token_index
+                        ] = True
+                    self.input_batch.num_tokens_no_spec[req_index] = end_token_index
 
             # Add spec_token_ids to token_ids_cpu.
             self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
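
Note: taken in isolation, the non-last-rank bookkeeping above can be sketched as a small standalone function (illustrative only: plain Python lists stand in for the numpy-backed CPU buffers, and the helper name append_new_tokens is hypothetical):

    def append_new_tokens(
        token_ids: list[int],      # stands in for token_ids_cpu[req_index]
        num_tokens_no_spec: int,   # tokens already bookkept for this request
        num_computed_tokens: int,  # tokens the scheduler reports as computed
        new_token_ids: list[int],  # tokens from the scheduler (may be empty)
    ) -> int:
        start = num_tokens_no_spec
        # Never move the count backwards: under chunked prefill
        # num_computed_tokens can be smaller than num_tokens_no_spec, and with
        # async scheduling new_token_ids is empty, so the count advances from
        # num_computed_tokens alone.
        end = max(start, num_computed_tokens + len(new_token_ids))
        if end > start and new_token_ids:
            # Append only the tokens that are not already in the buffer.
            token_ids[start:end] = new_token_ids[-(end - start):]
        return end  # the new num_tokens_no_spec

    # A request with 3 prompt tokens and one already-bookkept output token is
    # extended correctly; the old code instead reset the count to
    # num_computed_tokens + len(new_token_ids), which could shrink it and drop
    # already-cached tokens.
    buf = [1, 2, 3, 101, 0, 0, 0, 0]
    assert append_new_tokens(buf, num_tokens_no_spec=4,
                             num_computed_tokens=4, new_token_ids=[102]) == 5
    assert buf[:5] == [1, 2, 3, 101, 102]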
