Skip to content

Commit 80c7b08

Browse files
authored
[TPU] Async output processing for TPU (#8011)
1 parent 428dd14 commit 80c7b08

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

vllm/config.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -347,10 +347,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
347347
self.use_async_output_proc = False
348348
return
349349

350-
if device_config.device_type != "cuda":
350+
if device_config.device_type not in ("cuda", "tpu"):
351351
logger.warning(
352-
"Async output processing is only supported for CUDA."
353-
" Disabling it for other platforms.")
352+
"Async output processing is only supported for CUDA or TPU. "
353+
"Disabling it for other platforms.")
354354
self.use_async_output_proc = False
355355
return
356356

vllm/worker/tpu_model_runner.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import time
22
from dataclasses import dataclass
3-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
3+
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
4+
Type, Union)
45
from unittest.mock import patch
56

67
import numpy as np
@@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
5152
best_of: List[int]
5253
seq_groups: List[List[int]]
5354
virtual_engine: int = 0
55+
async_callback: Optional[Callable] = None
5456

5557
def as_broadcastable_tensor_dict(
5658
self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -562,6 +564,8 @@ def _execute_model(*args):
562564
model_input.attn_metadata, model_input.input_lens[i:i + 1],
563565
model_input.t[i:i + 1], model_input.p[i:i + 1],
564566
model_input.num_samples, kv_caches)
567+
if i == 0 and model_input.async_callback is not None:
568+
model_input.async_callback()
565569
# Retrieve the outputs to CPU.
566570
next_token_ids += output_token_ids.cpu().tolist()
567571
start_idx = end_idx
@@ -572,6 +576,8 @@ def _execute_model(*args):
572576
model_input.attn_metadata, model_input.input_lens,
573577
model_input.t, model_input.p, model_input.num_samples,
574578
kv_caches)
579+
if model_input.async_callback is not None:
580+
model_input.async_callback()
575581
# Retrieve the outputs to CPU.
576582
next_token_ids = output_token_ids.cpu().tolist()
577583

0 commit comments

Comments
 (0)