|
86 | 86 | from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext |
87 | 87 | from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch |
88 | 88 | from vllm.v1.worker.kv_connector_model_runner_mixin import ( |
89 | | - KVConnectorModelRunnerMixin, KVConnectorOutput) |
| 89 | + KVConnectorModelRunnerMixin) |
90 | 90 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
91 | 91 |
|
92 | 92 | from .utils import (AttentionGroup, MultiModalBudget, |
@@ -196,6 +196,14 @@ def __init__( |
196 | 196 | self.max_num_tokens = scheduler_config.max_num_batched_tokens |
197 | 197 | self.max_num_reqs = scheduler_config.max_num_seqs |
198 | 198 |
|
| 199 | + # Broadcast PP output for external_launcher (torchrun) |
| 200 | + # to make sure we are synced across pp ranks |
| 201 | + # TODO: Support overlapping micro-batches
| 202 | + # https://github.com/vllm-project/vllm/issues/18019
| 203 | + self.broadcast_pp_output = ( |
| 204 | + self.parallel_config.distributed_executor_backend |
| 205 | + == "external_launcher" and len(get_pp_group().ranks) > 0) |
| 206 | + |
199 | 207 | # Model-related. |
200 | 208 | self.num_query_heads = model_config.get_num_attention_heads( |
201 | 209 | parallel_config) |
@@ -1701,7 +1709,6 @@ def _pool( |
1701 | 1709 | hidden_states: torch.Tensor, |
1702 | 1710 | num_scheduled_tokens: int, |
1703 | 1711 | num_scheduled_tokens_np: np.ndarray, |
1704 | | - kv_connector_output: Optional[KVConnectorOutput], |
1705 | 1712 | ) -> ModelRunnerOutput: |
1706 | 1713 | assert self.input_batch.num_reqs ==\ |
1707 | 1714 | len(self.input_batch.pooling_params), \ |
@@ -1732,7 +1739,6 @@ def _pool( |
1732 | 1739 | logprobs=None, |
1733 | 1740 | prompt_logprobs_dict={}, |
1734 | 1741 | pooler_output=pooler_output, |
1735 | | - kv_connector_output=kv_connector_output, |
1736 | 1742 | ) |
1737 | 1743 |
|
1738 | 1744 | def _preprocess( |
@@ -2073,39 +2079,47 @@ def execute_model( |
2073 | 2079 |
|
2074 | 2080 | with record_function_or_nullcontext("Postprocess"): |
2075 | 2081 | if self.use_aux_hidden_state_outputs: |
| 2082 | + # True when EAGLE 3 is used. |
2076 | 2083 | hidden_states, aux_hidden_states = model_output |
2077 | 2084 | else: |
| 2085 | + # Common case. |
2078 | 2086 | hidden_states = model_output |
2079 | 2087 | aux_hidden_states = None |
2080 | 2088 |
|
2081 | | - # Broadcast PP output for external_launcher (torchrun) |
2082 | | - # to make sure we are synced across pp ranks |
2083 | | - # TODO: Support overlapping micro-batches
2084 | | - # https://github.com/vllm-project/vllm/issues/18019
2085 | | - broadcast_pp_output = \ |
2086 | | - self.parallel_config.distributed_executor_backend \ |
2087 | | - == "external_launcher" and len(get_pp_group().ranks) > 0 |
2088 | | - if not get_pp_group().is_last_rank: |
2089 | | - # For mid-pipeline stages, return the hidden states. |
2090 | | - assert isinstance(hidden_states, IntermediateTensors) |
2091 | | - if not broadcast_pp_output: |
| 2089 | + if not self.broadcast_pp_output: |
| 2090 | + # Common case. |
| 2091 | + if not get_pp_group().is_last_rank: |
| 2092 | + # Return the intermediate tensors. |
| 2093 | + assert isinstance(hidden_states, IntermediateTensors) |
2092 | 2094 | hidden_states.kv_connector_output = kv_connector_output |
2093 | 2095 | return hidden_states |
2094 | | - get_pp_group().send_tensor_dict( |
2095 | | - hidden_states.tensors, all_gather_group=get_tp_group()) |
2096 | | - logits = None |
2097 | | - else: |
| 2096 | + |
2098 | 2097 | if self.is_pooling_model: |
2099 | | - return self._pool(hidden_states, num_scheduled_tokens, |
2100 | | - num_scheduled_tokens_np, |
2101 | | - kv_connector_output) |
| 2098 | + # Return the pooling output. |
| 2099 | + output = self._pool(hidden_states, num_scheduled_tokens, |
| 2100 | + num_scheduled_tokens_np) |
| 2101 | + output.kv_connector_output = kv_connector_output |
| 2102 | + return output |
2102 | 2103 |
|
2103 | 2104 | sample_hidden_states = hidden_states[logits_indices] |
2104 | 2105 | logits = self.model.compute_logits(sample_hidden_states, None) |
2105 | | - if broadcast_pp_output: |
2106 | | - model_output_broadcast_data = { |
2107 | | - "logits": logits.contiguous(), |
2108 | | - } if logits is not None else {} |
| 2106 | + else: |
| 2107 | + # Rare case. |
| 2108 | + assert not self.is_pooling_model |
| 2109 | + |
| 2110 | + if not get_pp_group().is_last_rank: |
| 2111 | + get_pp_group().send_tensor_dict( |
| 2112 | + hidden_states.tensors, all_gather_group=get_tp_group()) |
| 2113 | + logits = None |
| 2114 | + else: |
| 2115 | + sample_hidden_states = hidden_states[logits_indices] |
| 2116 | + logits = self.model.compute_logits(sample_hidden_states, |
| 2117 | + None) |
| 2118 | + |
| 2119 | + model_output_broadcast_data = {} |
| 2120 | + if logits is not None: |
| 2121 | + model_output_broadcast_data["logits"] = logits.contiguous() |
| 2122 | + |
2109 | 2123 | model_output_broadcast_data = get_pp_group( |
2110 | 2124 | ).broadcast_tensor_dict(model_output_broadcast_data, |
2111 | 2125 | src=len(get_pp_group().ranks) - 1) |
|
0 commit comments