@@ -163,6 +163,34 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
     )


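+# Builds a SchedulerOutput containing only already-running (cached) requests:
+# no new requests and no newly allocated KV blocks, with per-request token
+# counts taken directly from the arguments.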
+def _schedule_cached_requests(
+    req_ids: list[str],
+    num_scheduled_tokens: dict[str, int],
+    new_token_ids: list[list[int]],
+    num_computed_tokens: list[int],
+    num_output_tokens: list[int],
+) -> SchedulerOutput:
+    return SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData(
+            req_ids=req_ids,
+            resumed_req_ids=set(),
+            new_token_ids=new_token_ids,
+            all_token_ids={},
+            new_block_ids=[None] * len(req_ids),
+            num_computed_tokens=num_computed_tokens,
+            num_output_tokens=num_output_tokens,
+        ),
+        num_scheduled_tokens=num_scheduled_tokens,
+        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=[],
+        finished_req_ids=set(),
+        free_encoder_mm_hashes=[],
+    )
+
+
 def _is_req_scheduled(model_runner, req_id: str) -> bool:
     return req_id in model_runner.input_batch.req_id_to_index

@@ -510,6 +538,135 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
     assert not _is_req_scheduled(model_runner, req_ids[1])


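+# Non-async scheduling with pipeline parallelism: the last PP rank has already
+# written the previously sampled tokens into its CPU buffers, while the
+# non-last rank only learns them via the scheduler's new_token_ids.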
+def test_update_states_pp_non_async_multi_request_keeps_token_buffers_consistent(
+    model_runner, model_runner_2, dist_init, monkeypatch
+):
+    req_ids = ["req_0", "req_1"]
+    non_last_runner = model_runner
+    last_runner = model_runner_2
+    non_last_runner.use_async_scheduling = False
+    last_runner.use_async_scheduling = False
+
+    # Both ranks start from the same request set.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(_schedule_new_request(*req_ids))
+    last_runner._update_states(_schedule_new_request(*req_ids))
+
+    sampled_by_last_rank = {req_ids[0]: 101, req_ids[1]: 201}
+    # Emulate last-rank bookkeeping result from previous step:
+    # sampled tokens already cached in CPU token buffers.
+    for req_id, token_id in sampled_by_last_rank.items():
+        req_index = last_runner.input_batch.req_id_to_index[req_id]
+        start_idx = int(last_runner.input_batch.num_tokens_no_spec[req_index])
+        end_idx = start_idx + 1
+        last_runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
+        last_runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
+        last_runner.input_batch.num_tokens_no_spec[req_index] = end_idx
+        last_runner.requests[req_id].output_token_ids.append(token_id)
+
+    scheduler_output = _schedule_cached_requests(
+        req_ids=req_ids,
+        num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
+        new_token_ids=[[101], [201]],
+        num_computed_tokens=[3, 3],  # prompt tokens only
+        num_output_tokens=[1, 1],
+    )
+    # non-last rank appends new_token_ids in _update_states.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(scheduler_output)
+    # last rank should keep its already-bookkept CPU buffers unchanged.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=True, world_size=2),
+    )
+    last_runner._update_states(scheduler_output)
+
+    # Verify consistency between PP ranks after _update_states.
+    for req_id in req_ids:
+        non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
+        last_idx = last_runner.input_batch.req_id_to_index[req_id]
+        non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
+        last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
+        assert non_last_len == last_len
+        assert (
+            non_last_runner.input_batch.token_ids_cpu[
+                non_last_idx, :non_last_len
+            ].tolist()
+            == last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
+        )
+
+
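+# Async-scheduling variant: the scheduler carries no new_token_ids, so each
+# rank must already hold the previously sampled tokens in its own buffers.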
+def test_update_states_pp_async_multi_request_keeps_rank_state_consistent(
+    model_runner, model_runner_2, dist_init, monkeypatch
+):
+    req_ids = ["req_0", "req_1"]
+    non_last_runner = model_runner
+    last_runner = model_runner_2
+    non_last_runner.use_async_scheduling = True
+    last_runner.use_async_scheduling = True
+
+    # Both ranks start from the same request set.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(_schedule_new_request(*req_ids))
+    last_runner._update_states(_schedule_new_request(*req_ids))
+
+    # Simulate async previous-step sampled tokens known on both ranks.
+    # non-last rank may receive them via PP communication; last rank has
+    # them from local sampling/bookkeeping.
+    sampled_by_last_rank = {req_ids[0]: 111, req_ids[1]: 222}
+    for runner in (non_last_runner, last_runner):
+        for req_id, token_id in sampled_by_last_rank.items():
+            req_index = runner.input_batch.req_id_to_index[req_id]
+            start_idx = int(runner.input_batch.num_tokens_no_spec[req_index])
+            end_idx = start_idx + 1
+            runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
+            runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
+            runner.input_batch.num_tokens_no_spec[req_index] = end_idx
+            runner.requests[req_id].output_token_ids.append(token_id)
+
+    scheduler_output = _schedule_cached_requests(
+        req_ids=req_ids,
+        num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
+        new_token_ids=[],
+        num_computed_tokens=[4, 4],
+        num_output_tokens=[1, 1],
+    )
+    # non-last rank: async PP branch (new_token_ids empty).
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=False, world_size=2),
+    )
+    non_last_runner._update_states(scheduler_output)
+    # last rank: keep already-bookkept state aligned with scheduler view.
+    monkeypatch.setattr(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group",
+        lambda: SimpleNamespace(is_last_rank=True, world_size=2),
+    )
+    last_runner._update_states(scheduler_output)
+
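+    # Verify consistency between PP ranks after _update_states.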
+    for req_id in req_ids:
+        non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
+        last_idx = last_runner.input_batch.req_id_to_index[req_id]
+        non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
+        last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
+        assert non_last_len == last_len
+        assert (
+            non_last_runner.input_batch.token_ids_cpu[
+                non_last_idx, :non_last_len
+            ].tolist()
+            == last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
+        )
+
+
 def test_kv_cache_stride_order(monkeypatch, model_runner):
     # This test checks if GPUModelRunner initializes correctly when an attention
     # backend enforces a non-default KV cache stride order.