Problem
In the pipeline-parallel path, every rank ends up running lm_head on the same broadcast hidden state, e.g. mlx_lm/models/deepseek_v3.py:357-361:
# Broadcast h while keeping it in the graph
if pipeline_size > 1:
    h = mx.distributed.all_gather(h)[: h.shape[0]]
return self.norm(h)
…then in the outer Model.__call__:
out = self.model(inputs, cache)
return self.lm_head(out)
For non-rank-0 ranks (PipelineMixin's "rank=0 gets the last layers" convention), the produced logits are identical to rank 0's and aren't actually used downstream — the lm_head call there is pure waste.
Cost
On a 2-node Apple Silicon PP setup running a ~400B MoE, the lm_head weight read is ~500MB per token. Skipping it on non-final ranks saves ~1.5 ms/token in our cluster (we've been carrying a fork patch for ~5 weeks across the Qwen3.5 family).
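(Back-of-envelope on those numbers: if the saved time is dominated by the weight read, 500 MB / 1.5 ms ≈ 330 GB/s effective read bandwidth, which is in the plausible range for Apple Silicon unified memory, so the two figures are at least consistent with each other.)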
Proposed approach
Add a guard at the top of Model.__call__ in each PipelineMixin-using model (deepseek_v2/v3, glm4_moe, glm4_moe_lite, ministral3):
out = self.model(inputs, cache)
if self.model.pipeline_rank != 0:
    return out
return self.lm_head(out)
~5 lines per model. (ministral3 has a tie_word_embeddings branch, so the guard wraps both paths — still ~5 lines, but not a literal copy of the snippet above.)
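For ministral3 specifically, the guard would sit above both return paths. A minimal sketch (assuming the tied-embeddings branch uses embed_tokens.as_linear as in other mlx_lm models; exact attribute names in that file may differ):

out = self.model(inputs, cache)
if self.model.pipeline_rank != 0:
    # Non-final ranks return the broadcast hidden state; only rank 0
    # (which holds the final layers) projects to the vocabulary.
    return out
if self.args.tie_word_embeddings:  # assumed flag name, per other mlx_lm models
    return self.model.embed_tokens.as_linear(out)
return self.lm_head(out)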
Open design question
This makes non-rank-0 ranks return shape (B, L, hidden_size) while rank-0 returns (B, L, vocab_size) — different shapes, different compile graphs across ranks. In practice each rank already compiles a different graph (different layer slices), but if there's a deliberate intent to keep the per-rank compiled graph homogeneous I'd rather know before sinking work into the PR.
An alternative: do the skip inside PipelineMixin (or a helper), gated by an explicit elide_output_head kwarg so callers opt in.
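A rough sketch of that variant (apply_output_head and elide_output_head are hypothetical names for illustration, not existing API):

# Hypothetical helper on PipelineMixin
def apply_output_head(self, h, head_fn, elide_output_head=False):
    if elide_output_head and self.pipeline_rank != 0:
        # Opted-in callers skip the vocab projection on non-final ranks
        # and get the hidden state back instead.
        return h
    return head_fn(h)

# In each Model.__call__, roughly:
# return self.model.apply_output_head(out, self.lm_head, elide_output_head=elide_output_head)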
Would the team accept either of these? Happy to PR whichever direction you prefer.