
fix(gemma4): add stop_gradient on MoE router top_k_indices#1238

Open
TrentCarter wants to merge 1 commit into ml-explore:main from TrentCarter:fix/gemma4-moe-router-stop-gradient

Conversation

@TrentCarter

Summary

Wrap top_k_indices in mx.stop_gradient inside gemma4_text.DecoderLayer.__call__, just after the router call. Router weights remain learnable via top_k_weights (softmax over the selected-experts subset), so the gating function still trains end-to-end; only the argmax-derived index path is detached.

```python
top_k_indices, top_k_weights = self.router(h)
top_k_indices = mx.stop_gradient(top_k_indices)  # <-- added
h2 = self.experts(h2, top_k_indices, top_k_weights)
```

Motivation

Discovered while doing LoRA SFT on a Gemma-4 27B IT (4-bit MoE) checkpoint. Without this patch, after a few hundred training steps the adapter-loaded model exhibits severe inference instability: 8-10K-character token loops, and fenced-code output that never closes. Adding stop_gradient on top_k_indices removes the failure mode while leaving routing learnable.

The change is also a no-op at inference: argpartition already returns integer indices, so it is cost-free in eval.
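For context, the selection math the patch touches can be sketched in plain NumPy. This is a simplified stand-in for the actual gemma4_text router (names, shapes, and the `route` helper are illustrative, not the model's code); the comment marks where the added mx.stop_gradient sits in the real path.

```python
import numpy as np

def route(logits, top_k):
    # pick the top_k expert indices per token (discrete selection);
    # in the patched model this index path is wrapped in mx.stop_gradient
    idx = np.argpartition(-logits, top_k - 1, axis=-1)[..., :top_k]
    # softmax over only the selected logits: this is the differentiable
    # path (top_k_weights) through which the router keeps training
    sel = np.take_along_axis(logits, idx, axis=-1)
    e = np.exp(sel - sel.max(axis=-1, keepdims=True))
    w = e / e.sum(axis=-1, keepdims=True)
    return idx, w

logits = np.array([[2.0, 0.0, 1.0, -1.0]])
idx, w = route(logits, top_k=2)
# experts 0 and 2 are selected; their weights renormalize to 1
```

Detaching `idx` changes nothing numerically here, since the indices only gather; all gradient already flows through `w`.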

Minimal repro outline

  1. Load mlx-community/gemma-4-26b-a4b-it-4bit with mlx_lm.load(...).
  2. Train a small LoRA adapter (function-style code SFT corpus is sufficient; ~hundreds of short samples).
  3. Run inference on held-out tasks: with the unpatched model, the assistant turn drifts into repeated tokens after the fence close; with the patch, generation terminates cleanly.

(Happy to share a more concrete repro on request — the downstream eval is internal but the failure mode is reproducible on any non-trivial adapter run against the 27B MoE variant.)

Test

Added test_gemma4_moe_router_top_k_indices_no_grad in tests/test_models.py. It builds a tiny MoE-enabled gemma4 model (hidden_size=8, num_experts=4, top_k_experts=2), runs a forward pass, and verifies via nn.value_and_grad that the router projection weight still receives non-zero gradient (i.e. the router stays learnable through top_k_weights). Runs in <0.1s.
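The property the test asserts — gradient still reaches the router projection through top_k_weights even though the index path is discrete — can also be seen with a framework-free finite-difference check. This is a NumPy stand-in under assumed toy shapes (the real test uses nn.value_and_grad on a tiny gemma4 model):

```python
import numpy as np

def gate_loss(W, x, top_k=2):
    # router projection -> top-k selection -> softmax over selected logits
    logits = x @ W
    idx = np.argpartition(-logits, top_k - 1, axis=-1)[..., :top_k]
    sel = np.take_along_axis(logits, idx, axis=-1)
    e = np.exp(sel - sel.max(axis=-1, keepdims=True))
    w = e / e.sum(axis=-1, keepdims=True)
    return float((w ** 2).sum())  # any scalar that depends on the gate

# deterministic setup: expert 0 clearly wins, expert 1 comes second,
# so a small perturbation never changes which experts are selected
x = np.ones((1, 2))
W = np.array([[2.0, 0.5, -1.0, 0.0],
              [0.0, 0.5, 0.0, -0.5]])

# central difference on one router weight; a non-zero value means the
# gate stays learnable even though the index path carries no gradient
eps = 1e-5
Wp, Wm = W.copy(), W.copy()
Wp[0, 0] += eps
Wm[0, 0] -= eps
g = (gate_loss(Wp, x) - gate_loss(Wm, x)) / (2 * eps)
```

With these numbers the selected logits are [2.0, 1.0], so analytically g = 2·w0·w1·(w0−w1) ≈ 0.1817, which the difference quotient recovers.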

Notes

  • Searched open and closed issues/PRs on ml-explore/mlx-lm for "gemma4 router stop_gradient" — found none. If this duplicates an in-flight discussion I missed, happy to close.
  • Black-formatted; pre-commit clean on the touched files.

Block gradient through discrete expert selection in
gemma4_text.DecoderLayer.__call__. Router weights remain learnable
via top_k_weights (softmax over selected experts), so gating still
trains; only the argmax-derived index path is detached.

Reproducer (downstream): adapter SFT (LoRA) on Gemma-4 27B IT
produces 8-10K char token loops at inference without this patch.
A regression test is included that asserts the router projection
weight still receives gradient on a tiny MoE-enabled gemma4 model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
