11import time
22from dataclasses import dataclass
3- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type , Union
3+ from typing import (TYPE_CHECKING , Any , Callable , Dict , List , Optional , Tuple ,
4+ Type , Union )
45from unittest .mock import patch
56
67import numpy as np
@@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
5152 best_of : List [int ]
5253 seq_groups : List [List [int ]]
5354 virtual_engine : int = 0
55+ async_callback : Optional [Callable ] = None
5456
5557 def as_broadcastable_tensor_dict (
5658 self ) -> Dict [str , Union [int , torch .Tensor ]]:
@@ -562,6 +564,8 @@ def _execute_model(*args):
562564 model_input .attn_metadata , model_input .input_lens [i :i + 1 ],
563565 model_input .t [i :i + 1 ], model_input .p [i :i + 1 ],
564566 model_input .num_samples , kv_caches )
567+ if i == 0 and model_input .async_callback is not None :
568+ model_input .async_callback ()
565569 # Retrieve the outputs to CPU.
566570 next_token_ids += output_token_ids .cpu ().tolist ()
567571 start_idx = end_idx
@@ -572,6 +576,8 @@ def _execute_model(*args):
572576 model_input .attn_metadata , model_input .input_lens ,
573577 model_input .t , model_input .p , model_input .num_samples ,
574578 kv_caches )
579+ if model_input .async_callback is not None :
580+ model_input .async_callback ()
575581 # Retrieve the outputs to CPU.
576582 next_token_ids = output_token_ids .cpu ().tolist ()
577583
0 commit comments