
Skip lm_head on non-rank-0 pipeline-parallel ranks #1228

Open
lawcontinue wants to merge 1 commit into ml-explore:main from lawcontinue:feat/skip-lm-head-non-rank0

Conversation

@lawcontinue

Summary

In pipeline-parallel inference, every rank runs lm_head on the same broadcast hidden state. For non-rank-0 ranks, the logits are identical to rank 0's and never used — the call is pure waste.

This PR adds a pipeline_rank != 0 guard before lm_head in all six PipelineMixin models, skipping the projection on non-final ranks.

Closes #1203.

Changes

Each model's Model.__call__ now returns early when pipeline_rank != 0:

out = self.model(inputs, cache)
# Non-rank-0 ranks hand back the hidden state; their logits would be
# identical to rank 0's and discarded anyway.
if self.pipeline_rank != 0:
    return out
return self.lm_head(out)
File                Model
deepseek_v2.py      DeepSeek V2
deepseek_v3.py      DeepSeek V3
deepseek_v32.py     DeepSeek V3.2
glm4_moe.py         GLM4 MoE
glm4_moe_lite.py    GLM4 MoE Lite
ministral3.py       Ministral 3 (handles the tie_word_embeddings branch)
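ministral3.py is the only one of the six with a second projection path. A rough sketch of how the guard sits before both branches (attribute names such as self.args.tie_word_embeddings and self.model.embed_tokens follow the usual mlx-lm pattern and are assumed here, not copied from the diff):

out = self.model(inputs, cache)
if self.pipeline_rank != 0:
    return out
if self.args.tie_word_embeddings:
    # Tied embeddings: reuse the input embedding matrix as the output projection.
    return self.model.embed_tokens.as_linear(out)
return self.lm_head(out)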

Performance

Measured on a 2-node Gemma-3-12B setup (Thunderbolt, MLX 0.31.1): skipping lm_head on the remote rank saves ~30% of per-step time. The savings scale with vocab_size — DeepSeek V3 with 129K vocab would see a larger benefit.
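For intuition about why the savings track vocab_size, here is a back-of-envelope estimate of the projection being skipped per token. The hidden_size and vocab_size below are assumed Gemma-3-class values, not figures reported in this PR:

# Rough cost of the skipped lm_head matmul per generated token.
hidden_size = 3840
vocab_size = 262_144
weights = hidden_size * vocab_size      # ~1.0e9 parameters in the projection
flops_per_token = 2 * weights           # one multiply-add per weight
bytes_read_4bit = weights / 2           # ~0.5 GB of weight reads at 4-bit quantization
print(f"{weights / 1e9:.2f}B params, "
      f"{flops_per_token / 1e9:.1f} GFLOPs/token, "
      f"{bytes_read_4bit / 1e9:.2f} GB read/token")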

Design notes

  • Non-rank-0 ranks now return (B, L, hidden_size) instead of (B, L, vocab_size). Each rank already compiles a different graph (different layer slices), so heterogeneous output shapes should not be an issue (see the toy shape check after this list).
  • The guard is zero-cost for single-rank inference (pipeline_rank defaults to 0 in PipelineMixin.__init__).
  • ministral3 has a tie_word_embeddings branch — the guard is placed before both paths, keeping the patch uniform across all six models.
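To make the shape point concrete, a self-contained toy (illustrative dimensions and class name, not the real models) showing that only rank 0 projects to the vocabulary:

import mlx.core as mx
import mlx.nn as nn

class ToyPipelineModel(nn.Module):
    def __init__(self, hidden=8, vocab=32, pipeline_rank=0):
        super().__init__()
        self.pipeline_rank = pipeline_rank
        self.body = nn.Linear(hidden, hidden)    # stand-in for this rank's layer slice
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def __call__(self, x):
        out = self.body(x)
        if self.pipeline_rank != 0:
            return out                           # (B, L, hidden)
        return self.lm_head(out)                 # (B, L, vocab)

x = mx.zeros((1, 4, 8))
print(ToyPipelineModel(pipeline_rank=0)(x).shape)  # (1, 4, 32)
print(ToyPipelineModel(pipeline_rank=1)(x).shape)  # (1, 4, 8)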

Commit message

In pipeline-parallel inference, non-final ranks compute identical
lm_head logits that are never used downstream. Skip the projection
on ranks > 0 to avoid a redundant large matmul (vocab_size can be
262K+).

Measured on a 2-node Gemma-3-12B setup: ~30% per-step time saved
by skipping the lm_head weight read on the remote rank.

Affects all PipelineMixin models: deepseek_v2, deepseek_v3,
deepseek_v32, glm4_moe, glm4_moe_lite, ministral3.

Closes ml-explore#1203