
Skip lm_head on non-final pipeline-parallel ranks? #1203

@adurham

Problem

In the pipeline-parallel path, every rank ends up running lm_head on the same broadcast hidden state, e.g. mlx_lm/models/deepseek_v3.py:357-361:

```python
# Broadcast h while keeping it in the graph
if pipeline_size > 1:
    h = mx.distributed.all_gather(h)[: h.shape[0]]
return self.norm(h)
```

…then in the outer Model.__call__:

```python
out = self.model(inputs, cache)
return self.lm_head(out)
```

On every rank other than 0 (PipelineMixin's convention is that rank 0 gets the last layers), the resulting logits are identical to rank 0's and aren't used downstream, so the lm_head call there is pure waste.
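A toy, pure-Python simulation of the broadcast step may help show why that call is redundant. It assumes, as in the excerpt above, that all_gather concatenates each rank's h along the first axis in rank order, so slicing back to the local batch size leaves rank 0's block on every rank. The helpers (simulated_all_gather, broadcast_from_rank0, toy_lm_head) are hypothetical stand-ins, not mlx APIs, and the sequence axis is dropped for brevity.

```python
# Toy, pure-Python model of `h = all_gather(h)[: h.shape[0]]`.
# Each rank's h is a list of batch rows (sequence axis dropped for brevity).
# All names here are hypothetical stand-ins, not mlx APIs.


def simulated_all_gather(per_rank_h):
    """Concatenate every rank's rows in rank order (axis 0)."""
    gathered = []
    for h in per_rank_h:
        gathered.extend(h)
    return gathered


def broadcast_from_rank0(per_rank_h):
    """Return what each rank holds after the gather-then-slice step."""
    gathered = simulated_all_gather(per_rank_h)
    # Slicing to the local batch size keeps only rank 0's block.
    return [gathered[: len(h)] for h in per_rank_h]


def toy_lm_head(h):
    """Any deterministic function of h stands in for lm_head."""
    return [[2 * x for x in row] for row in h]


# Two ranks, batch size 1. Rank 0 holds the real final hidden state;
# rank 1's local h is different before the broadcast.
per_rank_h = [[[1.0, 2.0]], [[9.0, 9.0]]]
after = broadcast_from_rank0(per_rank_h)
logits = [toy_lm_head(h) for h in after]
print(after)   # [[[1.0, 2.0]], [[1.0, 2.0]]]
print(logits)  # [[[2.0, 4.0]], [[2.0, 4.0]]]
```

Because every rank applies the same head to the same rows, the non-zero ranks only recompute rank 0's logits.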

Cost

On a 2-node Apple Silicon pipeline-parallel (PP) setup running a ~400B MoE, the lm_head weight read is ~500 MB per token. Skipping it on non-final ranks saves ~1.5 ms/token in our cluster (we've been carrying this as a fork patch for ~5 weeks across the Qwen3.5 family).
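As a rough sanity check on those figures: if the skipped work is purely memory-bound (an assumption), reading ~500 MB in ~1.5 ms corresponds to an effective bandwidth of roughly 333 GB/s. The size helper below uses hypothetical round dimensions, not the model from our cluster, and ignores quantization scales and biases.

```python
def lm_head_weight_bytes(vocab_size: int, hidden_size: int, bits_per_weight: int) -> int:
    """Approximate lm_head weight size in bytes.

    Ignores quantization scales/biases, so real quantized heads are a bit larger.
    """
    return vocab_size * hidden_size * bits_per_weight // 8


def implied_bandwidth_gb_s(bytes_read: float, seconds: float) -> float:
    """Effective read bandwidth, assuming the saved time is all weight reads."""
    return bytes_read / seconds / 1e9


# Hypothetical round dimensions with a 4-bit head (not the author's model).
size = lm_head_weight_bytes(vocab_size=150_000, hidden_size=7_000, bits_per_weight=4)
print(size / 1e6, "MB")  # 525.0 MB

# Figures from this issue: ~500 MB read, ~1.5 ms saved per token.
print(f"{implied_bandwidth_gb_s(500e6, 1.5e-3):.0f} GB/s")  # 333 GB/s
```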

Proposed approach

Add a guard in Model.__call__, right after the inner model call, in each PipelineMixin-using model (deepseek_v2/v3, glm4_moe, glm4_moe_lite, ministral3):

```python
out = self.model(inputs, cache)
# Non-zero ranks only hold a copy of rank 0's hidden state; skip the head.
if self.model.pipeline_rank != 0:
    return out
return self.lm_head(out)
```

That is ~5 lines per model. ministral3 has a tie_word_embeddings branch, so there the guard has to cover both output paths; it is still ~5 lines, just not a literal copy of the snippet above.
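For ministral3, here is a sketch of how the guard could cover both paths. It assumes the tied branch follows the pattern common in mlx_lm models (check tie_word_embeddings, then call self.model.embed_tokens.as_linear(out)); the Stub* classes are hypothetical placeholders so the control flow runs without mlx, and the flag is stored directly on the model rather than on self.args.

```python
# Control-flow sketch of the proposed guard, runnable without mlx.
# The Stub* classes are hypothetical placeholders, not mlx_lm classes.


class StubEmbedding:
    def as_linear(self, x):
        # Real nn.Embedding.as_linear projects onto the vocab with the tied weights.
        return ("tied_logits", x)


class StubInnerModel:
    def __init__(self, pipeline_rank):
        self.pipeline_rank = pipeline_rank
        self.embed_tokens = StubEmbedding()

    def __call__(self, inputs, cache=None):
        return ("hidden", inputs)


class StubHead:
    def __call__(self, x):
        return ("logits", x)


class Model:
    def __init__(self, pipeline_rank, tie_word_embeddings):
        self.model = StubInnerModel(pipeline_rank)
        # Assumption: real models usually read this flag from self.args.
        self.tie_word_embeddings = tie_word_embeddings
        if not tie_word_embeddings:
            self.lm_head = StubHead()

    def __call__(self, inputs, cache=None):
        out = self.model(inputs, cache)
        # Non-zero ranks hold a copy of rank 0's hidden state: skip both head paths.
        if self.model.pipeline_rank != 0:
            return out
        if self.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(out)
        return self.lm_head(out)


print(Model(0, tie_word_embeddings=True)("x"))  # ('tied_logits', ('hidden', 'x'))
print(Model(1, tie_word_embeddings=True)("x"))  # ('hidden', 'x')
```

Placing the rank check before the tie_word_embeddings branch keeps it to a single early return, which is why the change stays around 5 lines.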

Open design question

This makes non-zero ranks return shape (B, L, hidden_size) while rank 0 returns (B, L, vocab_size), so output shapes and compiled graphs differ across ranks. In practice each rank already compiles a different graph (each holds a different slice of layers), but if keeping the per-rank compiled graph homogeneous is a deliberate goal, I'd rather know before sinking work into the PR.

Alternative: do the skip inside PipelineMixin (or a shared helper), gated by an explicit elide_output_head kwarg so that callers opt in.
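A minimal sketch of what the opt-in helper might look like. output_head_or_passthrough is a hypothetical name (only the elide_output_head kwarg comes from this proposal), and plain shape tuples stand in for arrays so the shape difference discussed above is visible without mlx.

```python
HIDDEN, VOCAB = 8, 32  # toy sizes, not real model dimensions


def output_head_or_passthrough(out, head, pipeline_rank, elide_output_head=False):
    """Hypothetical helper: apply `head` unless the caller opted in to
    eliding it and this is not rank 0 (the rank holding the last layers)."""
    if elide_output_head and pipeline_rank != 0:
        return out
    return head(out)


def toy_head(shape):
    # Models only the shape change lm_head causes: (B, L, hidden) -> (B, L, vocab).
    b, l, _ = shape
    return (b, l, VOCAB)


h = (2, 5, HIDDEN)
print(output_head_or_passthrough(h, toy_head, pipeline_rank=1))  # (2, 5, 32)
print(output_head_or_passthrough(h, toy_head, pipeline_rank=1, elide_output_head=True))  # (2, 5, 8)
print(output_head_or_passthrough(h, toy_head, pipeline_rank=0, elide_output_head=True))  # (2, 5, 32)
```

With the default elide_output_head=False, every rank keeps returning (B, L, vocab_size), so existing callers see no change; only callers that opt in get the mixed shapes.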

Would the team accept either of these? Happy to PR whichever direction you prefer.
