fix(gemma4): add stop_gradient on MoE router top_k_indices #1238
Open
TrentCarter wants to merge 1 commit into ml-explore:main from
Conversation
Block gradient through discrete expert selection in gemma4_text.DecoderLayer.__call__. Router weights remain learnable via top_k_weights (softmax over selected experts), so gating still trains; only the argmax-derived index path is detached. Reproducer (downstream): adapter SFT (LoRA) on Gemma-4 27B IT produces 8-10K char token loops at inference without this patch. A regression test is included that asserts the router projection weight still receives gradient on a tiny MoE-enabled gemma4 model. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Wrap top_k_indices in mx.stop_gradient inside gemma4_text.DecoderLayer.__call__, just after the router call. Router weights remain learnable via top_k_weights (softmax over the selected-experts subset), so the gating function still trains end-to-end; only the argmax-derived index path is detached.
Motivation

Discovered while doing LoRA SFT on a Gemma-4 27B IT (4-bit MoE) checkpoint. Without this patch, after a few hundred training steps the adapter-loaded model exhibits severe inference instability (8-10K-character token loops; fenced-code output never closes). Adding stop_gradient on top_k_indices removes the failure mode while leaving routing learnable. The change is also a no-op at inference: argpartition already returns integer indices, so it is cost-free in eval.

Minimal repro outline
Load mlx-community/gemma-4-26b-a4b-it-4bit with mlx_lm.load(...). (Happy to share a more concrete repro on request — the downstream eval is internal, but the failure mode is reproducible on any non-trivial adapter run against the 27B MoE variant.)
Test
Added test_gemma4_moe_router_top_k_indices_no_grad in tests/test_models.py. It builds a tiny MoE-enabled gemma4 model (hidden_size=8, num_experts=4, top_k_experts=2), runs a forward pass, and verifies via nn.value_and_grad that the router projection weight still receives a non-zero gradient (i.e. the router stays learnable through top_k_weights). Runs in <0.1s.
Notes

Searched ml-explore/mlx-lm for "gemma4 router stop_gradient" — found none. If this duplicates an in-flight discussion I missed, happy to close.