@@ -134,9 +134,16 @@ def _capture_model_forward_graph(
     num_tokens: int,
     *,
     allow_compile: bool = False,
+    inputs_embeds: torch.Tensor | None = None,
 ) -> torch.cuda.CUDAGraph:
     def run_once() -> None:
-        outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
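+        # Graph body: run the usual token-ID forward unless an embeddings
+        # buffer was supplied at capture time, in which case feed the
+        # precomputed embeddings directly.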
+        if inputs_embeds is None:
+            outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
+        else:
+            outputs[:num_tokens] = self.model.forward_inputs_embeds(
+                inputs_embeds[:num_tokens],
+                positions[:num_tokens],
+            )

     stream = self._get_graph_capture_stream()
     pool = self._get_graph_pool()
@@ -291,6 +298,16 @@ def _model_hidden_dtype(self) -> torch.dtype:
     except StopIteration:
         return torch.get_default_dtype()

+def _use_cudagraph_inputs_embeds(self) -> bool:
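+    # Opt-in gate: requires the config flag, a supported LLaDA2 variant, and
+    # a model that exposes both forward_inputs_embeds and the word_embeddings
+    # table used to refill the buffer at replay time.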
+    if not bool(getattr(self.config, "enable_cudagraph_inputs_embeds", False)):
+        return False
+    if getattr(self.config, "model_name", None) not in {"llada2", "llada2_moe", "llada2_mini"}:
+        return False
+    return (
+        hasattr(self.model, "forward_inputs_embeds")
+        and hasattr(getattr(self.model, "model", None), "word_embeddings")
+    )
+
 def _ensure_runtime_static_buffers(
     self,
     *,
@@ -456,6 +473,12 @@ def _capture_prefill_cudagraph(self, bucket_len: int):
     device = self._cuda_graph_device()

     input_ids = torch.zeros(bucket_len, dtype=torch.int64, device=device)
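+    # Static embeddings buffer for graph replay, allocated only when the
+    # inputs_embeds path is enabled for this model.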
+    use_inputs_embeds = self._use_cudagraph_inputs_embeds()
+    inputs_embeds = (
+        torch.zeros(bucket_len, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
+        if use_inputs_embeds
+        else None
+    )
     positions = torch.zeros(bucket_len, dtype=torch.int64, device=device)
     slot_mapping = torch.full((bucket_len,), -1, dtype=torch.int32, device=device)
     context_lens = torch.zeros(req_capacity, dtype=torch.int32, device=device)
@@ -506,10 +529,18 @@ def _capture_prefill_cudagraph(self, bucket_len: int):
         padded_prefix_lens=padded_prefix_lens,
         outputs=outputs,
     )
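+    # Register the buffer under a stable key so the replay-time copy in
+    # _copy_common_graph_inputs can find it.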
+    if inputs_embeds is not None:
+        graph_vars["inputs_embeds"] = inputs_embeds
     graph_vars.update(self._prefill_graph_extra_vars(bucket_len, device))
     self._init_prefill_graph_extra_metadata(attn_metadata, graph_vars, bucket_len)

-    graph = self._capture_model_forward_graph(input_ids, positions, outputs, bucket_len)
+    graph = self._capture_model_forward_graph(
+        input_ids,
+        positions,
+        outputs,
+        bucket_len,
+        inputs_embeds=inputs_embeds,
+    )
     if self.graph_pool is None:
         self.graph_pool = graph.pool()
     torch.cuda.synchronize()
@@ -528,14 +559,16 @@ def _copy_common_graph_inputs(
     num_reqs: int,
 ) -> None:
     for key, value in graph_vars.items():
-        if key == "outputs":
+        if key in ("outputs", "inputs_embeds"):
             continue
         if key in ("slot_mapping", "page_tables"):
             value.fill_(-1)
         else:
             value.zero_()

     graph_vars["input_ids"][:num_tokens] = input_ids
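+    # inputs_embeds is exempt from the bulk zeroing above; instead refresh it
+    # by embedding the live token IDs into the captured buffer.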
+    if "inputs_embeds" in graph_vars:
+        graph_vars["inputs_embeds"][:num_tokens] = self.model.model.word_embeddings(input_ids)
     graph_vars["positions"][:num_tokens] = positions
     graph_vars["slot_mapping"][:num_tokens] = attn_metadata.slot_mapping
     graph_vars["context_lens"][:num_reqs] = attn_metadata.context_lens
@@ -767,6 +800,12 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
     device = self._cuda_graph_device()

     input_ids = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
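+    # Shared embeddings buffer sized for the largest bucket; each captured
+    # graph replays against a slice of it.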
+    use_inputs_embeds = self._use_cudagraph_inputs_embeds()
+    inputs_embeds = (
+        torch.zeros(max_num_tokens, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
+        if use_inputs_embeds
+        else None
+    )
     positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
     slot_mapping = torch.full((max_num_tokens,), -1, dtype=torch.int32, device=device)
     context_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
@@ -819,7 +858,14 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
         padded_prefix_lens=padded_prefix_lens[:num_seqs],
     )

-    graph = self._capture_model_forward_graph(input_ids, positions, outputs, num_tokens, allow_compile=True)
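+    # Capture against only the first num_tokens rows of the shared buffer;
+    # the allocation itself stays sized for max_num_tokens.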
+    graph = self._capture_model_forward_graph(
+        input_ids,
+        positions,
+        outputs,
+        num_tokens,
+        allow_compile=True,
+        inputs_embeds=inputs_embeds[:num_tokens] if inputs_embeds is not None else None,
+    )
     if self.graph_pool is None:
         self.graph_pool = graph.pool()
     self.graphs[num_tokens] = graph
@@ -840,4 +886,6 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
         padded_prefix_lens=padded_prefix_lens,
         outputs=outputs,
     )
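+    # Expose the buffer via graph_vars so replay-time input copies cover it.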
+    if inputs_embeds is not None:
+        self.graph_vars["inputs_embeds"] = inputs_embeds
     reset_warming_up()