
Commit b9b247d

Support per-expert MoE checkpoints in qwen3_5_moe.sanitize
Qwen-org Qwen3.6 MoE checkpoints (Qwen/Qwen3.6-35B-A3B and -A3B-FP8) ship
expert MLPs as one tensor per expert per projection:

    model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

Qwen3_5MoeModel.sanitize previously expected the combined-format layout
(experts.gate_up_proj / experts.down_proj) produced by mlx_lm.convert and
shipped by mlx-community redistributions. Loading a Qwen-org checkpoint
directly therefore fails strict load_weights with thousands of unexpected
keys (e.g. 30720 = 256 experts x 40 layers x 3 projections for 35B-A3B).

Add a second branch in sanitize that detects the per-expert layout and
stacks the per-expert tensors along axis 0 into the same switch_mlp.* form
the combined-format branch produces. Pre-stacked checkpoints take the
original branch unchanged. A defensive contiguity check raises ValueError
on non-contiguous indices (e.g. a checkpoint shipping experts {0, 1, 3} but
skipping 2), so a malformed checkpoint fails loudly rather than silently
dropping experts.

Tests in tests/test_models.py:
- test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp (positive)
- test_qwen3_5_moe_per_expert_gap_raises (defensive)
- test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp (regression)

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
1 parent ed1fca4 commit b9b247d
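
For illustration, a minimal sketch of the transformation the new branch performs for a single projection. The key pattern matches the commit message; the expert count, dimensions, and variable names are toy values invented here:

    import mlx.core as mx

    num_experts, moe_inter, hidden = 2, 6, 4
    prefix = "model.language_model.layers.0.mlp"

    # Per-expert layout: one (moe_inter, hidden) tensor per expert.
    weights = {
        f"{prefix}.experts.{e}.gate_proj.weight": mx.full((moe_inter, hidden), float(e))
        for e in range(num_experts)
    }

    # sanitize stacks these along axis 0 into one (num_experts, moe_inter, hidden)
    # tensor under the switch_mlp key, matching the combined format.
    stacked = mx.stack(
        [weights[f"{prefix}.experts.{e}.gate_proj.weight"] for e in range(num_experts)]
    )
    assert stacked.shape == (num_experts, moe_inter, hidden)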

2 files changed: 215 additions & 0 deletions


mlx_lm/models/qwen3_5_moe.py

Lines changed: 39 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 from dataclasses import dataclass
 
+import mlx.core as mx
+
 from .base import BaseModelArgs
 from .qwen3_5 import Model as Qwen3_5Model
 
@@ -48,5 +50,42 @@ def sanitize(self, weights):
                 new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop(
                     f"{prefix}.experts.down_proj"
                 )
+            elif f"{prefix}.experts.0.gate_proj.weight" in new_weights:
+                # Per-expert layout (Qwen3.6 MoE FP8/bf16): one tensor per expert
+                # per projection. Collect all matching keys for this layer,
+                # validate the index range is contiguous from 0, then stack along
+                # axis 0 into the same shape the combined-format branch above
+                # produces.
+                experts_prefix = f"{prefix}.experts."
+                gate_suffix = ".gate_proj.weight"
+                indices = set()
+                for key in new_weights:
+                    if key.startswith(experts_prefix) and key.endswith(gate_suffix):
+                        tail = key[len(experts_prefix) : -len(gate_suffix)]
+                        if tail.isdigit():
+                            indices.add(int(tail))
+                expected = set(range(len(indices)))
+                if indices != expected:
+                    missing = sorted(expected - indices)
+                    extra = sorted(indices - expected)
+                    raise ValueError(
+                        f"Per-expert MoE weights at {prefix}.experts have "
+                        f"non-contiguous indices: missing={missing}, "
+                        f"unexpected={extra}."
+                    )
+                gates, ups, downs = [], [], []
+                for e in range(len(indices)):
+                    gates.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.gate_proj.weight")
+                    )
+                    ups.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.up_proj.weight")
+                    )
+                    downs.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.down_proj.weight")
+                    )
+                new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = mx.stack(gates)
+                new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = mx.stack(ups)
+                new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = mx.stack(downs)
 
         return self.language_model.sanitize(new_weights)
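
A note on the contiguity check above: comparing the collected indices against set(range(len(indices))) flags both gaps and strays in one comparison, since a set of N distinct non-negative integers equals {0, ..., N-1} exactly when it has no gap. A detached sketch of that logic, using a hypothetical gapped index set:

    indices = {0, 1, 3}                   # expert 2 missing, as in the gap test below
    expected = set(range(len(indices)))   # {0, 1, 2}
    missing = sorted(expected - indices)  # [2]
    extra = sorted(indices - expected)    # [3]
    assert (missing, extra) == ([2], [3])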

tests/test_models.py

Lines changed: 176 additions & 0 deletions
@@ -649,6 +649,182 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
             mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
         )
 
+    def test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp(self):
+        # Qwen3.6 MoE checkpoints (e.g. Qwen/Qwen3.6-35B-A3B-FP8) ship expert
+        # MLPs as one tensor per expert per projection rather than as a single
+        # combined ``experts.gate_up_proj`` / ``experts.down_proj`` tensor.
+        # The sanitize step must stack them into the combined ``switch_mlp.*``
+        # form that ``Qwen3_5MoeSparseMlp.load_weights`` expects.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        per_expert = {
+            f"{prefix}.0.gate_proj.weight": mx.full((6, 4), 1.0),
+            f"{prefix}.0.up_proj.weight": mx.full((6, 4), 2.0),
+            f"{prefix}.0.down_proj.weight": mx.full((4, 6), 3.0),
+            f"{prefix}.1.gate_proj.weight": mx.full((6, 4), 4.0),
+            f"{prefix}.1.up_proj.weight": mx.full((6, 4), 5.0),
+            f"{prefix}.1.down_proj.weight": mx.full((4, 6), 6.0),
+        }
+        sanitized = model.sanitize(dict(per_expert))
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        gate_key = f"{out_prefix}.gate_proj.weight"
+        up_key = f"{out_prefix}.up_proj.weight"
+        down_key = f"{out_prefix}.down_proj.weight"
+
+        self.assertIn(gate_key, sanitized)
+        self.assertIn(up_key, sanitized)
+        self.assertIn(down_key, sanitized)
+        self.assertEqual(sanitized[gate_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[up_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[down_key].shape, (2, 4, 6))
+        # Per-expert keys must not leak through.
+        self.assertFalse(any(".experts.0." in k for k in sanitized))
+        self.assertFalse(any(".experts.1." in k for k in sanitized))
+        # Stacking preserves per-expert content along axis 0.
+        self.assertTrue(mx.array_equal(sanitized[gate_key][0], per_expert[f"{prefix}.0.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[gate_key][1], per_expert[f"{prefix}.1.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[down_key][1], per_expert[f"{prefix}.1.down_proj.weight"]))
+
+    def test_qwen3_5_moe_per_expert_gap_raises(self):
+        # Defensive: if a per-expert checkpoint has non-contiguous expert indices
+        # (e.g. ships 0, 1, 3 but not 2), sanitize must fail loudly rather than
+        # silently dropping expert 3, which would replicate the strict-load
+        # failure this branch is meant to fix.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 3,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        # Index 2 deliberately missing.
+        gapped = {
+            f"{prefix}.0.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.1.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.3.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.down_proj.weight": mx.zeros((4, 6)),
+        }
+        with self.assertRaisesRegex(ValueError, "non-contiguous"):
+            model.sanitize(dict(gapped))
+
+    def test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp(self):
+        # Regression guard: pre-stacked checkpoints (e.g. mlx-community Qwen3.5
+        # / 3.6 redistributions) ship ``experts.gate_up_proj`` /
+        # ``experts.down_proj`` as combined tensors. Sanitize must still split
+        # them into ``switch_mlp.{gate,up,down}_proj.weight`` via the original
+        # branch, untouched by the new per-expert path.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        # Pre-stacked: gate_up has shape (num_experts, 2*intermediate, hidden);
+        # down has shape (num_experts, hidden, intermediate).
+        gate_up = mx.arange(2 * 12 * 4, dtype=mx.float32).reshape(2, 12, 4)
+        down = mx.arange(2 * 4 * 6, dtype=mx.float32).reshape(2, 4, 6)
+        sanitized = model.sanitize(
+            {
+                "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up,
+                "model.language_model.layers.0.mlp.experts.down_proj": down,
+            }
+        )
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        self.assertEqual(sanitized[f"{out_prefix}.gate_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.up_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.down_proj.weight"].shape, (2, 4, 6))
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.gate_proj.weight"], gate_up[:, :6, :])
+        )
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.up_proj.weight"], gate_up[:, 6:, :])
+        )
+        self.assertTrue(mx.array_equal(sanitized[f"{out_prefix}.down_proj.weight"], down))
+        # Combined keys must not leak through after split.
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.gate_up_proj", sanitized)
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.down_proj", sanitized)
+
     def test_gemma4_convert_then_load_keeps_language_model_prefix(self):
         from mlx_lm.models import gemma4
 
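The three new tests can be run in isolation with unittest's -k filter, e.g. python -m unittest tests.test_models -k qwen3_5_moe (assuming the repository's tests are importable as the tests package; the filter pattern here is illustrative).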