Skip to content

Commit bdac9ed

Browse files
committed
feat: add support for inputs_embeds in CUDA graph execution

- Introduced `enable_cudagraph_inputs_embeds` configuration option to allow capturing CUDA graphs from `inputs_embeds` instead of `input_ids`, enhancing flexibility in model execution.
- Updated `LLaDA2Model` and `MultiBlockModelRunnerTemplate` to support the new inputs_embeds functionality, improving performance for specific use cases.
- Modified argument parsing and configuration loading to accommodate the new inputs_embeds settings, ensuring seamless integration with existing benchmarks.
1 parent 22a626f commit bdac9ed

8 files changed

Lines changed: 92 additions & 8 deletions

File tree

diffulex/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class Config:
7777
attn_impl: str = "triton" # "triton" or "naive"
7878
enable_prefill_cudagraph: bool = True
7979
enable_full_static_runner: bool = True
80+
enable_cudagraph_inputs_embeds: bool = False
8081
prefill_cudagraph_max_len: int = 0
8182
enable_torch_compile: bool = True
8283
enable_cudagraph_torch_compile: bool = False

diffulex/layer/embed_head.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def _linear_into_workspace(self, x: torch.Tensor) -> torch.Tensor:
120120
logits.add_(self.bias)
121121
return logits
122122

123-
def forward(self, x: torch.Tensor):
123+
def forward(self, x: torch.Tensor, gather: bool = True):
124124
logits = self._linear_into_workspace(x)
125-
if self.tp_size > 1:
125+
if gather and self.tp_size > 1:
126126
if LM_HEAD_FP32_GATHER:
127127
logits_dtype = logits.dtype
128128
logits = _tp_gather_to_rank0(logits.to(torch.float32), self.tp_group, self.tp_size, self.tp_rank)

diffulex/model/llada2.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,15 @@ def forward(
363363
mask: torch.Tensor | None = None,
364364
) -> torch.Tensor:
365365
hidden_states = self.word_embeddings(input_ids)
366+
return self.forward_inputs_embeds(hidden_states, positions, mask)
367+
368+
def forward_inputs_embeds(
369+
self,
370+
inputs_embeds: torch.Tensor,
371+
positions: torch.Tensor,
372+
mask: torch.Tensor | None = None,
373+
) -> torch.Tensor:
374+
hidden_states = inputs_embeds
366375
hidden_states = self._maybe_apply_token_merging(hidden_states)
367376
for layer in self.layers:
368377
hidden_states = layer(positions, hidden_states, mask)
@@ -497,8 +506,16 @@ def forward(
497506
) -> torch.Tensor:
498507
return self.model(input_ids, positions, mask)
499508

500-
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
501-
return self.lm_head(hidden_states)
509+
def forward_inputs_embeds(
510+
self,
511+
inputs_embeds: torch.Tensor,
512+
positions: torch.Tensor,
513+
mask: torch.Tensor | None = None,
514+
) -> torch.Tensor:
515+
return self.model.forward_inputs_embeds(inputs_embeds, positions, mask)
516+
517+
def compute_logits(self, hidden_states: torch.Tensor, gather: bool = True) -> torch.Tensor:
518+
return self.lm_head(hidden_states, gather=gather)
502519

503520

504521
SparseMoEBlock = FusedMoE

diffulex/server/args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ServerArgs:
3838
max_model_len: int = 2048
3939
enable_prefill_cudagraph: bool = True
4040
enable_full_static_runner: bool = True
41+
enable_cudagraph_inputs_embeds: bool = False
4142
prefill_cudagraph_max_len: int = 0
4243
enable_torch_compile: bool = True
4344
enable_cudagraph_torch_compile: bool = False
@@ -81,6 +82,7 @@ def engine_kwargs(self) -> dict:
8182
"max_model_len": self.max_model_len,
8283
"enable_prefill_cudagraph": self.enable_prefill_cudagraph,
8384
"enable_full_static_runner": self.enable_full_static_runner,
85+
"enable_cudagraph_inputs_embeds": self.enable_cudagraph_inputs_embeds,
8486
"prefill_cudagraph_max_len": self.prefill_cudagraph_max_len,
8587
"enable_torch_compile": self.enable_torch_compile,
8688
"enable_cudagraph_torch_compile": self.enable_cudagraph_torch_compile,
@@ -137,6 +139,7 @@ def build_arg_parser() -> argparse.ArgumentParser:
137139
parser.add_argument("--max-model-len", type=int, default=2048)
138140
parser.add_argument("--disable-prefill-cudagraph", action="store_true")
139141
parser.add_argument("--disable-full-static-runner", action="store_true")
142+
parser.add_argument("--enable-cudagraph-inputs-embeds", action="store_true")
140143
parser.add_argument("--prefill-cudagraph-max-len", type=int, default=0)
141144
parser.add_argument("--disable-torch-compile", action="store_true")
142145
parser.add_argument("--enable-cudagraph-torch-compile", action="store_true")
@@ -189,6 +192,7 @@ def parse_args(argv: Sequence[str] | None = None) -> ServerArgs:
189192
max_model_len=ns.max_model_len,
190193
enable_prefill_cudagraph=not ns.disable_prefill_cudagraph,
191194
enable_full_static_runner=not ns.disable_full_static_runner,
195+
enable_cudagraph_inputs_embeds=ns.enable_cudagraph_inputs_embeds,
192196
prefill_cudagraph_max_len=ns.prefill_cudagraph_max_len,
193197
enable_torch_compile=not ns.disable_torch_compile,
194198
enable_cudagraph_torch_compile=ns.enable_cudagraph_torch_compile,

diffulex/strategy_template/multi_block/engine/model_runner.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,16 @@ def _capture_model_forward_graph(
134134
num_tokens: int,
135135
*,
136136
allow_compile: bool = False,
137+
inputs_embeds: torch.Tensor | None = None,
137138
) -> torch.cuda.CUDAGraph:
138139
def run_once() -> None:
139-
outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
140+
if inputs_embeds is None:
141+
outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens])
142+
else:
143+
outputs[:num_tokens] = self.model.forward_inputs_embeds(
144+
inputs_embeds[:num_tokens],
145+
positions[:num_tokens],
146+
)
140147

141148
stream = self._get_graph_capture_stream()
142149
pool = self._get_graph_pool()
@@ -291,6 +298,16 @@ def _model_hidden_dtype(self) -> torch.dtype:
291298
except StopIteration:
292299
return torch.get_default_dtype()
293300

301+
def _use_cudagraph_inputs_embeds(self) -> bool:
302+
if not bool(getattr(self.config, "enable_cudagraph_inputs_embeds", False)):
303+
return False
304+
if getattr(self.config, "model_name", None) not in {"llada2", "llada2_moe", "llada2_mini"}:
305+
return False
306+
return (
307+
hasattr(self.model, "forward_inputs_embeds")
308+
and hasattr(getattr(self.model, "model", None), "word_embeddings")
309+
)
310+
294311
def _ensure_runtime_static_buffers(
295312
self,
296313
*,
@@ -456,6 +473,12 @@ def _capture_prefill_cudagraph(self, bucket_len: int):
456473
device = self._cuda_graph_device()
457474

458475
input_ids = torch.zeros(bucket_len, dtype=torch.int64, device=device)
476+
use_inputs_embeds = self._use_cudagraph_inputs_embeds()
477+
inputs_embeds = (
478+
torch.zeros(bucket_len, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
479+
if use_inputs_embeds
480+
else None
481+
)
459482
positions = torch.zeros(bucket_len, dtype=torch.int64, device=device)
460483
slot_mapping = torch.full((bucket_len,), -1, dtype=torch.int32, device=device)
461484
context_lens = torch.zeros(req_capacity, dtype=torch.int32, device=device)
@@ -506,10 +529,18 @@ def _capture_prefill_cudagraph(self, bucket_len: int):
506529
padded_prefix_lens=padded_prefix_lens,
507530
outputs=outputs,
508531
)
532+
if inputs_embeds is not None:
533+
graph_vars["inputs_embeds"] = inputs_embeds
509534
graph_vars.update(self._prefill_graph_extra_vars(bucket_len, device))
510535
self._init_prefill_graph_extra_metadata(attn_metadata, graph_vars, bucket_len)
511536

512-
graph = self._capture_model_forward_graph(input_ids, positions, outputs, bucket_len)
537+
graph = self._capture_model_forward_graph(
538+
input_ids,
539+
positions,
540+
outputs,
541+
bucket_len,
542+
inputs_embeds=inputs_embeds,
543+
)
513544
if self.graph_pool is None:
514545
self.graph_pool = graph.pool()
515546
torch.cuda.synchronize()
@@ -528,14 +559,16 @@ def _copy_common_graph_inputs(
528559
num_reqs: int,
529560
) -> None:
530561
for key, value in graph_vars.items():
531-
if key == "outputs":
562+
if key in ("outputs", "inputs_embeds"):
532563
continue
533564
if key in ("slot_mapping", "page_tables"):
534565
value.fill_(-1)
535566
else:
536567
value.zero_()
537568

538569
graph_vars["input_ids"][:num_tokens] = input_ids
570+
if "inputs_embeds" in graph_vars:
571+
graph_vars["inputs_embeds"][:num_tokens] = self.model.model.word_embeddings(input_ids)
539572
graph_vars["positions"][:num_tokens] = positions
540573
graph_vars["slot_mapping"][:num_tokens] = attn_metadata.slot_mapping
541574
graph_vars["context_lens"][:num_reqs] = attn_metadata.context_lens
@@ -767,6 +800,12 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
767800
device = self._cuda_graph_device()
768801

769802
input_ids = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
803+
use_inputs_embeds = self._use_cudagraph_inputs_embeds()
804+
inputs_embeds = (
805+
torch.zeros(max_num_tokens, hf_config.hidden_size, dtype=self._model_hidden_dtype(), device=device)
806+
if use_inputs_embeds
807+
else None
808+
)
770809
positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
771810
slot_mapping = torch.full((max_num_tokens,), -1, dtype=torch.int32, device=device)
772811
context_lens = torch.zeros(max_num_seqs, dtype=torch.int32, device=device)
@@ -819,7 +858,14 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
819858
padded_prefix_lens=padded_prefix_lens[:num_seqs],
820859
)
821860

822-
graph = self._capture_model_forward_graph(input_ids, positions, outputs, num_tokens, allow_compile=True)
861+
graph = self._capture_model_forward_graph(
862+
input_ids,
863+
positions,
864+
outputs,
865+
num_tokens,
866+
allow_compile=True,
867+
inputs_embeds=inputs_embeds[:num_tokens] if inputs_embeds is not None else None,
868+
)
823869
if self.graph_pool is None:
824870
self.graph_pool = graph.pool()
825871
self.graphs[num_tokens] = graph
@@ -840,4 +886,6 @@ def capture_cudagraph_multi_block(self: ModelRunnerBase):
840886
padded_prefix_lens=padded_prefix_lens,
841887
outputs=outputs,
842888
)
889+
if inputs_embeds is not None:
890+
self.graph_vars["inputs_embeds"] = inputs_embeds
843891
reset_warming_up()

diffulex_bench/arg_parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,12 @@ def create_argument_parser() -> argparse.ArgumentParser:
421421
default=None,
422422
help="Use the full-static CUDA graph runner for supported multi-block forward passes",
423423
)
424+
parser.add_argument(
425+
"--enable-cudagraph-inputs-embeds",
426+
action=argparse.BooleanOptionalAction,
427+
default=None,
428+
help="For LLaDA2 only: capture CUDA graphs from inputs_embeds instead of input_ids",
429+
)
424430
parser.add_argument(
425431
"--prefill-cudagraph-max-len",
426432
type=int,

diffulex_bench/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class EngineConfig:
128128
max_num_reqs: int = 128
129129
enable_prefill_cudagraph: bool = True
130130
enable_full_static_runner: bool = True
131+
enable_cudagraph_inputs_embeds: bool = False
131132
prefill_cudagraph_max_len: int = 0
132133
enable_torch_compile: bool = True
133134
enable_cudagraph_torch_compile: bool = False

diffulex_bench/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,8 @@ def apply_engine_arg_overrides(engine: EngineConfig) -> None:
458458
config.engine.enable_prefill_cudagraph = bool(args.enable_prefill_cudagraph)
459459
if getattr(args, "enable_full_static_runner", None) is not None:
460460
config.engine.enable_full_static_runner = bool(args.enable_full_static_runner)
461+
if getattr(args, "enable_cudagraph_inputs_embeds", None) is not None:
462+
config.engine.enable_cudagraph_inputs_embeds = bool(args.enable_cudagraph_inputs_embeds)
461463
if (
462464
was_provided("prefill_cudagraph_max_len")
463465
and getattr(args, "prefill_cudagraph_max_len", None) is not None
@@ -534,6 +536,11 @@ def apply_engine_arg_overrides(engine: EngineConfig) -> None:
534536
if getattr(args, "enable_full_static_runner", None) is not None
535537
else True
536538
),
539+
enable_cudagraph_inputs_embeds=(
540+
bool(getattr(args, "enable_cudagraph_inputs_embeds", False))
541+
if getattr(args, "enable_cudagraph_inputs_embeds", None) is not None
542+
else False
543+
),
537544
prefill_cudagraph_max_len=(getattr(args, "prefill_cudagraph_max_len", None) or 0),
538545
enable_torch_compile=(
539546
bool(getattr(args, "enable_torch_compile", True))

0 commit comments

Comments (0)