CUDA graph replay can fail with block_tables shape mismatch #190

Description

Summary

Decode CUDA graph replay can fail under higher-concurrency / longer-context workloads because the captured block_tables buffer can be narrower than the runtime context.block_tables.

The failure shows up during graph replay preparation when copying context.block_tables into the captured graph buffer.

Observed error

RuntimeError: The expanded size of the tensor (16) must match the existing size (17)
at non-singleton dimension 1. Target sizes: [32, 16]. Tensor sizes: [32, 17]

Root cause

In capture_cudagraph(), the graph buffer is allocated with:

max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)

But during decode replay, context.block_tables.size(1) can be one column wider than that captured width. When replay prep does:

graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables

the assignment fails with a shape mismatch.
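
For reference, the mismatch reproduces in isolation (a minimal sketch; the 32×16 vs 32×17 shapes come from the observed error above, and the variable names are hypothetical):

import torch

captured = torch.zeros(32, 16, dtype=torch.int32)  # graph buffer captured with 16 block columns
runtime = torch.zeros(32, 17, dtype=torch.int32)   # runtime block_tables, one column wider

# Same pattern as graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
captured[:32, :runtime.size(1)] = runtime  # raises: expanded size (16) must match existing size (17)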

There is also a secondary correctness issue: the graph buffer is reused across replays, but block_tables is not cleared before the current step's values are copied in, so trailing columns can retain stale block indices from an earlier step that used a wider table (see the sketch below).
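
A small sketch of that stale-value problem (hypothetical shapes and values, not the real block tables):

import torch

buf = torch.zeros(4, 5, dtype=torch.int32)               # reused graph buffer
buf[:4, :5] = torch.full((4, 5), 7, dtype=torch.int32)   # previous step used all 5 columns
buf[:4, :3] = torch.ones(4, 3, dtype=torch.int32)        # current step only writes 3 columns
print(buf[0])  # tensor([1, 1, 1, 7, 7]) -> columns 3 and 4 still hold stale block ids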

Proposed fix

  1. Allocate one extra block_tables column during graph capture.
  2. Clear the captured block_tables buffer with -1 before copying the current step's block tables.

Suggested patch

graph_vars["block_tables"].fill_(-1)
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables

and

max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + 1
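
A quick sanity check of the combined change, using the shapes from the observed error (a sketch only; the variable names are illustrative):

import torch

max_num_blocks = 16 + 1                                   # original captured width plus the extra column
captured = torch.zeros(32, max_num_blocks, dtype=torch.int32)
runtime = torch.arange(32 * 17, dtype=torch.int32).reshape(32, 17)

captured.fill_(-1)                                        # clear stale entries from the previous step
captured[:32, :runtime.size(1)] = runtime                 # now fits: 17 columns into 17
assert torch.equal(captured[:, :17], runtime)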
