Skip to content

Commit bbe9c7e

Browse files
committed
Revert "Refactor graph input buffers (#18991)" (#19173)
1 parent 901957a commit bbe9c7e

File tree

8 files changed

+494
-700
lines changed

8 files changed

+494
-700
lines changed

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 6 additions & 203 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,8 @@
2121
import logging
2222
import os
2323
from contextlib import contextmanager
24-
from dataclasses import dataclass
2524
from functools import partial
26-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
25+
from typing import TYPE_CHECKING, Callable, Optional, Union
2726

2827
import torch
2928
import tqdm
@@ -59,10 +58,9 @@
5958
ForwardBatch,
6059
ForwardMode,
6160
PPProxyTensors,
62-
compute_local_num_token_non_padded,
6361
enable_num_token_non_padded,
6462
)
65-
from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
63+
from sglang.srt.model_executor.input_buffers import GraphInputBuffers
6664
from sglang.srt.multiplex.pdmux_context import get_current_stream_idx, get_stream_groups
6765
from sglang.srt.utils import (
6866
empty_context,
@@ -92,200 +90,6 @@
9290
if TYPE_CHECKING:
9391
from sglang.srt.model_executor.model_runner import ModelRunner
9492

95-
96-
@dataclass
97-
class DecodeInputBuffers(ForwardInputBuffers):
98-
99-
input_ids: torch.Tensor
100-
input_embeds: torch.Tensor
101-
req_pool_indices: torch.Tensor
102-
seq_lens: torch.Tensor
103-
seq_lens_cpu: torch.Tensor
104-
out_cache_loc: torch.Tensor
105-
positions: torch.Tensor
106-
mrope_positions: torch.Tensor
107-
num_token_non_padded: torch.Tensor
108-
custom_mask: torch.Tensor
109-
next_token_logits_buffer: torch.Tensor
110-
mamba_track_indices: Optional[torch.Tensor]
111-
mamba_track_mask: Optional[torch.Tensor]
112-
global_num_tokens_gpu: torch.Tensor
113-
global_num_tokens_for_logprob_gpu: torch.Tensor
114-
encoder_lens: Optional[torch.Tensor]
115-
pp_proxy_tensors: Optional[Dict[str, torch.Tensor]]
116-
117-
@classmethod
118-
def create(
119-
cls,
120-
*,
121-
device: torch.device,
122-
max_bs: int,
123-
max_num_token: int,
124-
hidden_size: int,
125-
vocab_size: int,
126-
dtype: torch.dtype,
127-
dp_size: int,
128-
pp_size: int,
129-
is_encoder_decoder: bool,
130-
require_mlp_tp_gather: bool,
131-
seq_len_fill_value: int,
132-
encoder_len_fill_value: int,
133-
num_tokens_per_bs: int,
134-
cache_loc_dtype: torch.dtype,
135-
enable_mamba_track: bool,
136-
) -> "DecodeInputBuffers":
137-
with torch.device(device):
138-
input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
139-
input_embeds = torch.zeros((max_num_token, hidden_size), dtype=dtype)
140-
req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
141-
seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
142-
out_cache_loc = torch.zeros((max_num_token,), dtype=cache_loc_dtype)
143-
positions = torch.zeros((max_num_token,), dtype=torch.int64)
144-
mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)
145-
num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
146-
custom_mask = torch.ones(
147-
(max_bs * seq_len_fill_value + max_num_token) * num_tokens_per_bs,
148-
dtype=torch.bool,
149-
)
150-
next_token_logits_buffer = torch.zeros(
151-
(max_num_token, vocab_size),
152-
dtype=torch.float,
153-
)
154-
mamba_track_indices = (
155-
torch.zeros((max_bs,), dtype=torch.int64)
156-
if enable_mamba_track
157-
else None
158-
)
159-
mamba_track_mask = (
160-
torch.zeros((max_bs,), dtype=torch.bool) if enable_mamba_track else None
161-
)
162-
163-
if pp_size > 1:
164-
pp_proxy_tensors = {
165-
"hidden_states": torch.zeros((max_bs, hidden_size), dtype=dtype),
166-
"residual": torch.zeros((max_bs, hidden_size), dtype=dtype),
167-
}
168-
else:
169-
pp_proxy_tensors = None
170-
171-
if is_encoder_decoder:
172-
encoder_lens = torch.full(
173-
(max_bs,), encoder_len_fill_value, dtype=torch.int32
174-
)
175-
else:
176-
encoder_lens = None
177-
178-
if require_mlp_tp_gather:
179-
global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)
180-
global_num_tokens_for_logprob_gpu = torch.zeros(
181-
(dp_size,), dtype=torch.int32
182-
)
183-
else:
184-
global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
185-
global_num_tokens_for_logprob_gpu = torch.zeros((1,), dtype=torch.int32)
186-
187-
# Keep seq_lens_cpu as a true CPU tensor, like the old implementation.
188-
seq_lens_cpu = torch.full(
189-
(max_bs,),
190-
seq_len_fill_value,
191-
dtype=torch.int32,
192-
device="cpu",
193-
)
194-
195-
return cls(
196-
input_ids=input_ids,
197-
input_embeds=input_embeds,
198-
req_pool_indices=req_pool_indices,
199-
seq_lens=seq_lens,
200-
seq_lens_cpu=seq_lens_cpu,
201-
out_cache_loc=out_cache_loc,
202-
positions=positions,
203-
mrope_positions=mrope_positions,
204-
num_token_non_padded=num_token_non_padded,
205-
custom_mask=custom_mask,
206-
next_token_logits_buffer=next_token_logits_buffer,
207-
mamba_track_indices=mamba_track_indices,
208-
mamba_track_mask=mamba_track_mask,
209-
encoder_lens=encoder_lens,
210-
global_num_tokens_gpu=global_num_tokens_gpu,
211-
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
212-
pp_proxy_tensors=pp_proxy_tensors,
213-
)
214-
215-
def populate_from_forward_batch(
216-
self,
217-
*,
218-
forward_batch: ForwardBatch,
219-
raw_bs: int,
220-
raw_num_token: int,
221-
bs: int,
222-
seq_len_fill_value: int,
223-
require_gathered_buffer: bool,
224-
num_tokens_per_bs: int,
225-
nsa_enable_prefill_cp: bool,
226-
enable_num_token_non_padded_flag: bool,
227-
pp_proxy_tensors: Optional[PPProxyTensors] = None,
228-
):
229-
if bs != raw_bs:
230-
self.seq_lens.fill_(seq_len_fill_value)
231-
self.out_cache_loc.zero_()
232-
if self.mamba_track_indices is not None:
233-
self.mamba_track_indices.zero_()
234-
if self.mamba_track_mask is not None:
235-
self.mamba_track_mask.fill_(False)
236-
237-
# Common inputs
238-
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
239-
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
240-
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
241-
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
242-
self.positions[:raw_num_token].copy_(forward_batch.positions)
243-
244-
if (
245-
self.mamba_track_indices is not None
246-
and forward_batch.mamba_track_indices is not None
247-
):
248-
self.mamba_track_indices[:raw_bs].copy_(forward_batch.mamba_track_indices)
249-
if (
250-
self.mamba_track_mask is not None
251-
and forward_batch.mamba_track_mask is not None
252-
):
253-
self.mamba_track_mask[:raw_bs].copy_(forward_batch.mamba_track_mask)
254-
255-
if forward_batch.seq_lens_cpu is not None:
256-
if bs != raw_bs:
257-
self.seq_lens_cpu.fill_(seq_len_fill_value)
258-
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
259-
260-
if self.encoder_lens is not None and forward_batch.encoder_lens is not None:
261-
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
262-
263-
if forward_batch.mrope_positions is not None:
264-
self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
265-
266-
if require_gathered_buffer:
267-
self.global_num_tokens_gpu.fill_(bs * num_tokens_per_bs)
268-
self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs)
269-
270-
if enable_num_token_non_padded_flag:
271-
if require_gathered_buffer and not nsa_enable_prefill_cp:
272-
num_tokens_per_dp = bs * num_tokens_per_bs
273-
local = compute_local_num_token_non_padded(
274-
global_num_token_non_padded=forward_batch.num_token_non_padded,
275-
num_tokens_per_dp=num_tokens_per_dp,
276-
)
277-
self.num_token_non_padded.copy_(local)
278-
else:
279-
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
280-
281-
# Pipeline-parallel proxy tensors.
282-
if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
283-
for key, buf in self.pp_proxy_tensors.items():
284-
src = pp_proxy_tensors.tensors[key]
285-
dim = src.shape[0]
286-
buf[:dim].copy_(src)
287-
288-
28993
# Detect whether the current forward pass is in capture mode
29094
is_capture_mode = False
29195

@@ -533,7 +337,7 @@ def __init__(self, model_runner: ModelRunner):
533337

534338
if self.require_gathered_buffer:
535339
assert self.require_mlp_tp_gather or self.require_attn_tp_gather
536-
self.buffers: DecodeInputBuffers = DecodeInputBuffers.create(
340+
self.buffers: GraphInputBuffers = GraphInputBuffers.create(
537341
device=self.device,
538342
max_bs=self.max_bs,
539343
max_num_token=self.max_num_token,
@@ -550,7 +354,6 @@ def __init__(self, model_runner: ModelRunner):
550354
cache_loc_dtype=self._cache_loc_dtype(),
551355
enable_mamba_track=enable_mamba_track,
552356
)
553-
self.buffers.share_buffers()
554357

555358
self.tbo_plugin = TboCudaGraphRunnerPlugin()
556359

@@ -753,7 +556,7 @@ def _create_device_graph(self):
753556
def capture_one_batch_size(
754557
self, bs: int, forward: Callable, stream_idx: Optional[int] = None
755558
):
756-
buffers: DecodeInputBuffers = self.buffers
559+
buffers: GraphInputBuffers = self.buffers
757560
graph = self._create_device_graph()
758561
stream = self.stream
759562
num_tokens = bs * self.num_tokens_per_bs
@@ -995,7 +798,7 @@ def replay_prepare(
995798
index = bisect.bisect_left(self.capture_bs, raw_bs)
996799
bs = self.capture_bs[index]
997800

998-
buffers.populate_from_forward_batch(
801+
seq_lens_cpu = buffers.populate_from_forward_batch(
999802
forward_batch=forward_batch,
1000803
raw_bs=raw_bs,
1001804
raw_num_token=raw_num_token,
@@ -1032,7 +835,7 @@ def replay_prepare(
1032835
buffers.encoder_lens[:bs] if self.is_encoder_decoder else None,
1033836
self.capture_forward_mode,
1034837
forward_batch.spec_info,
1035-
seq_lens_cpu=buffers.seq_lens_cpu[:bs],
838+
seq_lens_cpu=seq_lens_cpu,
1036839
)
1037840

1038841
# Store fields

0 commit comments

Comments
 (0)