2121import logging
2222import os
2323from contextlib import contextmanager
24- from dataclasses import dataclass
2524from functools import partial
26- from typing import TYPE_CHECKING , Callable , Dict , Optional , Union
25+ from typing import TYPE_CHECKING , Callable , Optional , Union
2726
2827import torch
2928import tqdm
5958 ForwardBatch ,
6059 ForwardMode ,
6160 PPProxyTensors ,
62- compute_local_num_token_non_padded ,
6361 enable_num_token_non_padded ,
6462)
65- from sglang .srt .model_executor .input_buffers import ForwardInputBuffers
63+ from sglang .srt .model_executor .input_buffers import GraphInputBuffers
6664from sglang .srt .multiplex .pdmux_context import get_current_stream_idx , get_stream_groups
6765from sglang .srt .utils import (
6866 empty_context ,
9290if TYPE_CHECKING :
9391 from sglang .srt .model_executor .model_runner import ModelRunner
9492
95-
@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
    """Preallocated tensors that decode inputs are copied into before a run.

    ``create`` allocates every buffer once at its maximum size, and
    ``populate_from_forward_batch`` copies the live part of a
    ``ForwardBatch`` into the leading slice of each buffer. The graph
    runner in this file reads these buffers during capture and replay.
    """

    # Field order is part of the dataclass constructor signature; keep it.
    input_ids: torch.Tensor
    input_embeds: torch.Tensor
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor  # always allocated on the CPU (see ``create``)
    out_cache_loc: torch.Tensor
    positions: torch.Tensor
    mrope_positions: torch.Tensor
    num_token_non_padded: torch.Tensor
    custom_mask: torch.Tensor
    next_token_logits_buffer: torch.Tensor
    mamba_track_indices: Optional[torch.Tensor]  # None unless mamba tracking is on
    mamba_track_mask: Optional[torch.Tensor]  # None unless mamba tracking is on
    global_num_tokens_gpu: torch.Tensor
    global_num_tokens_for_logprob_gpu: torch.Tensor
    encoder_lens: Optional[torch.Tensor]  # None unless is_encoder_decoder
    pp_proxy_tensors: Optional[Dict[str, torch.Tensor]]  # None unless pp_size > 1

    @classmethod
    def create(
        cls,
        *,
        device: torch.device,
        max_bs: int,
        max_num_token: int,
        hidden_size: int,
        vocab_size: int,
        dtype: torch.dtype,
        dp_size: int,
        pp_size: int,
        is_encoder_decoder: bool,
        require_mlp_tp_gather: bool,
        seq_len_fill_value: int,
        encoder_len_fill_value: int,
        num_tokens_per_bs: int,
        cache_loc_dtype: torch.dtype,
        enable_mamba_track: bool,
    ) -> "DecodeInputBuffers":
        """Allocate all buffers at their maximum size and wrap them.

        Args:
            device: Default device for every buffer except ``seq_lens_cpu``.
            max_bs: Length of the per-request buffers.
            max_num_token: Length of the per-token buffers.
            hidden_size: Trailing dimension of ``input_embeds`` and of the
                pipeline-parallel proxy tensors.
            vocab_size: Trailing dimension of ``next_token_logits_buffer``.
            dtype: Element type of ``input_embeds`` and the proxy tensors.
            dp_size: Length of the ``global_num_tokens*`` buffers when
                ``require_mlp_tp_gather`` is set.
            pp_size: Proxy tensors are only allocated when this exceeds 1.
            is_encoder_decoder: Whether to allocate ``encoder_lens``.
            require_mlp_tp_gather: Selects ``dp_size`` (True) or 1 (False) as
                the length of the ``global_num_tokens*`` buffers.
            seq_len_fill_value: Initial value of ``seq_lens`` and
                ``seq_lens_cpu``; also enters the ``custom_mask`` length.
            encoder_len_fill_value: Initial value of ``encoder_lens``.
            num_tokens_per_bs: Multiplier in the ``custom_mask`` length.
            cache_loc_dtype: Element type of ``out_cache_loc``.
            enable_mamba_track: Whether to allocate the mamba tracking buffers.

        Returns:
            A new ``DecodeInputBuffers`` holding the freshly created tensors.
        """
        # Sizes that depend only on the arguments are worked out up front.
        mask_len = (max_bs * seq_len_fill_value + max_num_token) * num_tokens_per_bs
        gathered_len = dp_size if require_mlp_tp_gather else 1

        with torch.device(device):
            # Entries are evaluated top to bottom, so tensors are allocated in
            # a fixed, deterministic order.
            fields = dict(
                input_ids=torch.zeros((max_num_token,), dtype=torch.int64),
                input_embeds=torch.zeros((max_num_token, hidden_size), dtype=dtype),
                req_pool_indices=torch.zeros((max_bs,), dtype=torch.int32),
                seq_lens=torch.full(
                    (max_bs,), seq_len_fill_value, dtype=torch.int32
                ),
                out_cache_loc=torch.zeros((max_num_token,), dtype=cache_loc_dtype),
                positions=torch.zeros((max_num_token,), dtype=torch.int64),
                mrope_positions=torch.zeros((3, max_num_token), dtype=torch.int64),
                num_token_non_padded=torch.zeros((1,), dtype=torch.int32),
                custom_mask=torch.ones(mask_len, dtype=torch.bool),
                next_token_logits_buffer=torch.zeros(
                    (max_num_token, vocab_size), dtype=torch.float
                ),
                mamba_track_indices=(
                    torch.zeros((max_bs,), dtype=torch.int64)
                    if enable_mamba_track
                    else None
                ),
                mamba_track_mask=(
                    torch.zeros((max_bs,), dtype=torch.bool)
                    if enable_mamba_track
                    else None
                ),
                pp_proxy_tensors=(
                    {
                        name: torch.zeros((max_bs, hidden_size), dtype=dtype)
                        for name in ("hidden_states", "residual")
                    }
                    if pp_size > 1
                    else None
                ),
                encoder_lens=(
                    torch.full(
                        (max_bs,), encoder_len_fill_value, dtype=torch.int32
                    )
                    if is_encoder_decoder
                    else None
                ),
                global_num_tokens_gpu=torch.zeros(
                    (gathered_len,), dtype=torch.int32
                ),
                global_num_tokens_for_logprob_gpu=torch.zeros(
                    (gathered_len,), dtype=torch.int32
                ),
                # The explicit device keeps this one on the CPU regardless of
                # the surrounding default-device context.
                seq_lens_cpu=torch.full(
                    (max_bs,),
                    seq_len_fill_value,
                    dtype=torch.int32,
                    device="cpu",
                ),
            )

        return cls(**fields)

    def populate_from_forward_batch(
        self,
        *,
        forward_batch: ForwardBatch,
        raw_bs: int,
        raw_num_token: int,
        bs: int,
        seq_len_fill_value: int,
        require_gathered_buffer: bool,
        num_tokens_per_bs: int,
        nsa_enable_prefill_cp: bool,
        enable_num_token_non_padded_flag: bool,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        """Copy the contents of ``forward_batch`` into the static buffers.

        Args:
            forward_batch: Source of the tensors to copy.
            raw_bs: Number of leading per-request entries to overwrite.
            raw_num_token: Number of leading per-token entries to overwrite.
            bs: Batch size actually run; when it differs from ``raw_bs`` the
                buffers that need neutral padding are reset first.
            seq_len_fill_value: Padding value for ``seq_lens``/``seq_lens_cpu``.
            require_gathered_buffer: When true, both ``global_num_tokens*``
                buffers are filled with ``bs * num_tokens_per_bs``.
            num_tokens_per_bs: Tokens per request used in that product.
            nsa_enable_prefill_cp: When true, the batch's
                ``num_token_non_padded`` is copied as-is even if gathered
                buffers are required.
            enable_num_token_non_padded_flag: Gate for updating
                ``num_token_non_padded`` at all.
            pp_proxy_tensors: Optional pipeline-parallel inputs; copied only
                if this object also owns proxy buffers.
        """
        is_padded = bs != raw_bs

        if is_padded:
            # Reset whole buffers so entries past raw_bs / raw_num_token do
            # not keep values from an earlier call.
            self.seq_lens.fill_(seq_len_fill_value)
            self.out_cache_loc.zero_()
            if self.mamba_track_indices is not None:
                self.mamba_track_indices.zero_()
            if self.mamba_track_mask is not None:
                self.mamba_track_mask.fill_(False)

        # Common inputs
        for dst, src in (
            (self.input_ids[:raw_num_token], forward_batch.input_ids),
            (self.req_pool_indices[:raw_bs], forward_batch.req_pool_indices),
            (self.seq_lens[:raw_bs], forward_batch.seq_lens),
            (self.out_cache_loc[:raw_num_token], forward_batch.out_cache_loc),
            (self.positions[:raw_num_token], forward_batch.positions),
        ):
            dst.copy_(src)

        # Optional buffers: copied only when both sides provide a tensor.
        track_indices = forward_batch.mamba_track_indices
        if self.mamba_track_indices is not None and track_indices is not None:
            self.mamba_track_indices[:raw_bs].copy_(track_indices)

        track_mask = forward_batch.mamba_track_mask
        if self.mamba_track_mask is not None and track_mask is not None:
            self.mamba_track_mask[:raw_bs].copy_(track_mask)

        host_seq_lens = forward_batch.seq_lens_cpu
        if host_seq_lens is not None:
            if is_padded:
                self.seq_lens_cpu.fill_(seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(host_seq_lens)

        enc_lens = forward_batch.encoder_lens
        if self.encoder_lens is not None and enc_lens is not None:
            self.encoder_lens[:raw_bs].copy_(enc_lens)

        mrope = forward_batch.mrope_positions
        if mrope is not None:
            self.mrope_positions[:, :raw_num_token].copy_(mrope)

        if require_gathered_buffer:
            padded_tokens = bs * num_tokens_per_bs
            self.global_num_tokens_gpu.fill_(padded_tokens)
            self.global_num_tokens_for_logprob_gpu.fill_(padded_tokens)

        if enable_num_token_non_padded_flag:
            non_padded = forward_batch.num_token_non_padded
            if require_gathered_buffer and not nsa_enable_prefill_cp:
                # Convert the batch-provided (global) count via the helper
                # before storing it.
                non_padded = compute_local_num_token_non_padded(
                    global_num_token_non_padded=non_padded,
                    num_tokens_per_dp=bs * num_tokens_per_bs,
                )
            self.num_token_non_padded.copy_(non_padded)

        # Pipeline-parallel proxy tensors.
        if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
            for key, buf in self.pp_proxy_tensors.items():
                incoming = pp_proxy_tensors.tensors[key]
                buf[: incoming.shape[0]].copy_(incoming)
287-
288-
28993# Detect whether the current forward pass is in capture mode
29094is_capture_mode = False
29195
@@ -533,7 +337,7 @@ def __init__(self, model_runner: ModelRunner):
533337
534338 if self .require_gathered_buffer :
535339 assert self .require_mlp_tp_gather or self .require_attn_tp_gather
536- self .buffers : DecodeInputBuffers = DecodeInputBuffers .create (
340+ self .buffers : GraphInputBuffers = GraphInputBuffers .create (
537341 device = self .device ,
538342 max_bs = self .max_bs ,
539343 max_num_token = self .max_num_token ,
@@ -550,7 +354,6 @@ def __init__(self, model_runner: ModelRunner):
550354 cache_loc_dtype = self ._cache_loc_dtype (),
551355 enable_mamba_track = enable_mamba_track ,
552356 )
553- self .buffers .share_buffers ()
554357
555358 self .tbo_plugin = TboCudaGraphRunnerPlugin ()
556359
@@ -753,7 +556,7 @@ def _create_device_graph(self):
753556 def capture_one_batch_size (
754557 self , bs : int , forward : Callable , stream_idx : Optional [int ] = None
755558 ):
756- buffers : DecodeInputBuffers = self .buffers
559+ buffers : GraphInputBuffers = self .buffers
757560 graph = self ._create_device_graph ()
758561 stream = self .stream
759562 num_tokens = bs * self .num_tokens_per_bs
@@ -995,7 +798,7 @@ def replay_prepare(
995798 index = bisect .bisect_left (self .capture_bs , raw_bs )
996799 bs = self .capture_bs [index ]
997800
998- buffers .populate_from_forward_batch (
801+ seq_lens_cpu = buffers .populate_from_forward_batch (
999802 forward_batch = forward_batch ,
1000803 raw_bs = raw_bs ,
1001804 raw_num_token = raw_num_token ,
@@ -1032,7 +835,7 @@ def replay_prepare(
1032835 buffers .encoder_lens [:bs ] if self .is_encoder_decoder else None ,
1033836 self .capture_forward_mode ,
1034837 forward_batch .spec_info ,
1035- seq_lens_cpu = buffers . seq_lens_cpu [: bs ] ,
838+ seq_lens_cpu = seq_lens_cpu ,
1036839 )
1037840
1038841 # Store fields
0 commit comments