@@ -4,7 +4,7 @@ Diffusion language models have shown promise for non-autoregressive text generation

## Example Launch Command

SGLang supports different DLLM algorithms such as `LowConfidence`, `JointThreshold`, and `FullAttnMultiBlock`.

```shell
python3 -m sglang.launch_server \
    ...  # remaining arguments collapsed in this diff
```
@@ -51,6 +51,24 @@ max_post_edit_steps: 16
```yaml
penalty_lambda: 0
```

`FullAttnMultiBlock` Config (for bidirectional models like LLaDA and DREAM):

```yaml
# Confidence threshold for accepting predicted tokens
# Range: 0.0 - 1.0
threshold: 0.5
# Additional threshold increment per decoding step
block_add: 0.1
# Threshold for considering a token as "decoded"
decoded_thresh: 0.95
# Sub-block size for parallel decoding
sub_block_size: 32
# Number of iterations to delay before caching
cache_delay_iter: 2
# Interval for refreshing the attention cache
refresh_interval: 10000
```
**Contributor** commented on lines +56 to +70 (severity: medium):
The configuration keys and descriptions for `FullAttnMultiBlock` in this documentation seem to be inconsistent with the implementation in `python/sglang/srt/dllm/algorithm/full_attn_multi_block.py`.

Specifically:

- The key `block_add` in the YAML should likely be `block_add_threshold`. Its description "Additional threshold increment per decoding step" is also misleading. Based on the code, it is the "previous block progress threshold to add the next block".
- The key `decoded_thresh` should likely be `decoded_token_threshold`. Its description "Threshold for considering a token as 'decoded'" is also not quite accurate. It is the "previous block progress threshold for full activation".
- The key `sub_block_size` is documented here but does not appear to be used in the `FullAttnMultiBlock` algorithm implementation.

Could you please update the documentation to match the implementation for clarity and correctness? This will help users configure the algorithm correctly.

Suggested change:

```yaml
# Confidence threshold for accepting predicted tokens
# Range: 0.0 - 1.0
threshold: 0.5
# Previous block progress threshold to add the next block
# Range: 0.0 - 1.0
block_add_threshold: 0.1
# Previous block progress threshold for a block to be considered fully active
# Range: 0.0 - 1.0
decoded_token_threshold: 0.95
# Number of iterations to delay before caching
cache_delay_iter: 2
# Interval for refreshing the attention cache
refresh_interval: 10000
```
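
For reference, a short sketch of loading and range-checking such a config (assumes PyYAML and a hypothetical file name; key names follow the suggested fix above, not independently verified against the implementation):

```python
import yaml

# Hypothetical config file using the corrected key names.
with open("full_attn_multi_block.yaml") as f:
    cfg = yaml.safe_load(f)

# The three thresholds are confidence/progress fractions in [0, 1].
for key in ("threshold", "block_add_threshold", "decoded_token_threshold"):
    assert 0.0 <= cfg[key] <= 1.0, f"{key}={cfg[key]} must be in [0, 1]"

# Iteration-related settings are positive integers.
for key in ("cache_delay_iter", "refresh_interval"):
    assert isinstance(cfg[key], int) and cfg[key] > 0, f"{key} must be a positive int"
```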


## Example Client Code Snippet

Just like other supported models, diffusion language models can be used via the REST API or Python client.
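
For example, a minimal Python client might look like the following (a sketch with a hypothetical prompt and sampling values; it assumes the server launched above is listening on `127.0.0.1:30000`):

```python
import requests

# Send a prompt to the /generate endpoint of the running SGLang server.
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "What is a diffusion language model?",  # hypothetical prompt
        "sampling_params": {
            "max_new_tokens": 128,
            "temperature": 0.0,
        },
    },
)
print(response.json()["text"])
```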
@@ -104,8 +122,12 @@ curl -X POST "http://127.0.0.1:30000/generate" \

The supported models are summarized in the table below.

| Model Family | Example Model | Algorithm | Description |
| -------------------------- | ---------------------------- | ------------------- | ---------------------------------------------------------------------------------------------------- |
| **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-mini` | LowConfidence | LLaDA2.0-mini is a diffusion language model with dense architecture. |
| **LLaDA2.0 (mini, flash)** | `inclusionAI/LLaDA2.0-flash` | LowConfidence | LLaDA2.0-flash is a diffusion language model featuring a 100B Mixture-of-Experts (MoE) architecture. |
| **LLaDA2.1-mini** | `inclusionAI/LLaDA2.1-mini` | LowConfidence | LLaDA2.1-mini is an improved version of LLaDA2.0-mini with better performance. |
| **d3LLM-LLaDA** | `d3LLM/d3LLM_LLaDA` | FullAttnMultiBlock | Bidirectional diffusion LLM based on LLaDA architecture with full attention. |
| **d3LLM-Dream** | `d3LLM/d3LLM_Dream` | FullAttnMultiBlock | Bidirectional diffusion LLM based on DREAM architecture with full attention. |
| **SDAR (JetLM)** | `JetLM/SDAR-8B-Chat` | JointThreshold | SDAR series diffusion language model (Chat), dense architecture. |
| **SDAR (JetLM)** | `JetLM/SDAR-30B-A3B-Chat` | JointThreshold | SDAR series diffusion language model (Chat), MoE architecture. |
python/sglang/srt/dllm/algorithm/entropy_threshold.py (126 additions, 0 deletions)
@@ -0,0 +1,126 @@
from typing import List, Tuple, Union

import torch

from sglang.srt.dllm.algorithm.base import DllmAlgorithm
from sglang.srt.dllm.config import DllmConfig
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner


def sample_tokens_with_entropy(logits, temperature=1.0):
    """Compute per-position entropy and sample tokens. Copied from d3LLM."""
    original_probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(original_probs + 1e-8)
    entropy = -torch.sum(original_probs * log_probs, dim=-1)

    if temperature == 0:
        # Greedy decoding
        samples = torch.argmax(logits, dim=-1)
    else:
        scaled_logits = logits / temperature
        # Convert to probabilities and sample
        probs = torch.softmax(scaled_logits, dim=-1)
        samples = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return entropy, samples


class EntropyThreshold(DllmAlgorithm):

    def __init__(
        self,
        config: DllmConfig,
    ):
        super().__init__(config)
        self.threshold = config.algorithm_config.get("threshold", 0.95)

    def run(
        self,
        model_runner: ModelRunner,
        forward_batch: ForwardBatch,
    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], List[torch.Tensor], bool]:
        batch_size = forward_batch.batch_size
        # Per-item token counts: supports variable-length sequences (needs_full_prefill)
        item_lens = forward_batch.extend_seq_lens_cpu
        offsets = [0]
        for seq_len in item_lens:
            offsets.append(offsets[-1] + seq_len)
        start_list = []
        mask_index = forward_batch.input_ids == self.mask_id

        # Fast path: if there is no mask token, forward and save kv cache
        if torch.sum(mask_index).item() == 0:
            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)
            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph

            next_token_ids = []
            return logits_output, next_token_ids, can_run_cuda_graph

        # Calculate start positions for each batch item
        for i in range(batch_size):
            block_input_ids = forward_batch.input_ids[offsets[i] : offsets[i + 1]]
            block_mask_index = block_input_ids == self.mask_id
            start = item_lens[i] - torch.sum(block_mask_index).item()
            start_list.append(start)

        # Iterative unmasking: at most block_size forward passes per block
        nfe = 0
        for _ in range(self.block_size):
            mask_index = forward_batch.input_ids == self.mask_id
            if torch.sum(mask_index).item() == 0:
                break

            out = model_runner.forward(forward_batch, pp_proxy_tensors=None)
            nfe += 1
            logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph
            for batch_id in range(batch_size):
                s, e = offsets[batch_id], offsets[batch_id + 1]
                block_input_ids = forward_batch.input_ids[s:e]
                block_mask_index = block_input_ids == self.mask_id
                if torch.sum(block_mask_index).item() == 0:
                    continue
                curr_logits = logits_output.full_logits[s:e]

                # Entropy-based selection (matching d3LLM's entropy_threshold algorithm)
                mask_logits = curr_logits[block_mask_index]
                entropy, x0 = sample_tokens_with_entropy(mask_logits, temperature=0)

                # Scatter sampled tokens and entropies back to full-block positions;
                # non-masked positions get infinite entropy so they are never selected
                x = block_input_ids.clone()
                full_entropy = torch.full_like(
                    block_input_ids, float("inf"), dtype=curr_logits.dtype
                )

                x[block_mask_index] = x0
                full_entropy[block_mask_index] = entropy

                num_mask = block_mask_index.sum().item()
                selected_entropy, select_index = torch.topk(
                    full_entropy, num_mask, largest=False
                )
                transfer_index = torch.zeros_like(block_input_ids, dtype=torch.bool)

                # Always accept the lowest-entropy token; accept others if entropy < threshold
                transfer_index[select_index[0]] = True
                for k in range(1, num_mask):
                    if selected_entropy[k] < self.threshold:
                        transfer_index[select_index[k]] = True

                block_input_ids[transfer_index] = x[transfer_index]

        # Final forward after all tokens are decoded (saves kv cache, as in the fast path)
        out = model_runner.forward(forward_batch, pp_proxy_tensors=None)
        nfe += 1
        logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph
        # Build per-sequence next_token_ids using variable-length offsets
        next_token_ids_list = [
            forward_batch.input_ids[offsets[i] + start_list[i] : offsets[i + 1]]
            for i in range(batch_size)
        ]

        # Tokens per forward: attach nfe so it flows via customized_info -> meta_info
        logits_output.customized_info = {"nfe": [nfe] * batch_size}
        return logits_output, next_token_ids_list, can_run_cuda_graph


Algorithm = EntropyThreshold
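
As a standalone sanity check of the entropy-based accept rule, here is a small sketch (illustrative only; dummy logits, no SGLang dependencies beyond the `sample_tokens_with_entropy` helper defined above):

```python
import torch

# Dummy logits for 4 masked positions over a 6-token vocabulary.
logits = torch.tensor(
    [
        [8.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # confident -> low entropy
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # uniform -> entropy ln(6) ~ 1.79
        [4.0, 3.5, 0.0, 0.0, 0.0, 0.0],  # torn between two tokens
        [6.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # fairly confident
    ]
)

# temperature=0 -> greedy argmax, as used inside EntropyThreshold.run
entropy, tokens = sample_tokens_with_entropy(logits, temperature=0)

# Accept rule from run(): always take the lowest-entropy position,
# then any other position whose entropy is below the threshold.
threshold = 0.95
accepted = entropy < threshold
accepted[torch.argmin(entropy)] = True

print("entropy :", entropy)
print("tokens  :", tokens)
print("accepted:", accepted)
```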