Commit cd039ff

Support per-expert MoE checkpoints in qwen3_5_moe.sanitize
Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs as one tensor per expert per projection:

    model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked (experts.gate_up_proj / experts.down_proj) and loads via the existing combined-format branch unchanged; only the FP8 release lands per-expert, likely because Qwen's FP8 quantization pipeline runs per expert and the artifact is not re-stacked after quantization. Loading Qwen/Qwen3.6-35B-A3B-FP8 today fails strict load_weights with thousands of unexpected keys (30720 = 256 experts x 40 layers x 3 projections).

Add a second branch in Qwen3_5MoeModel.sanitize that detects the per-expert layout and stacks the per-expert tensors along axis 0 into the same switch_mlp.* form the combined-format branch produces. Pre-stacked checkpoints (mlx-community redistributions, Qwen3.5-MoE, Qwen3.6 bf16 MoE) take the original branch unchanged. A defensive contiguity check raises ValueError on non-contiguous expert indices (e.g. a checkpoint shipping experts {0, 1, 3} but skipping 2), so a malformed checkpoint fails loudly rather than silently dropping experts.

Tests in tests/test_models.py:

- test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp (positive)
- test_qwen3_5_moe_per_expert_gap_raises (defensive)
- test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp (regression)

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
1 parent ed1fca4 commit cd039ff
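For reference, a minimal sketch of the stacking the new branch performs, on toy tensors rather than the real checkpoint (the 2-expert count, 6x4 shapes, and variable names below are illustrative only, not taken from the model):

import mlx.core as mx

# Per-expert layout as shipped by the FP8 release (illustrative sizes).
prefix = "model.language_model.layers.0.mlp"
weights = {
    f"{prefix}.experts.0.gate_proj.weight": mx.zeros((6, 4)),
    f"{prefix}.experts.1.gate_proj.weight": mx.ones((6, 4)),
}

# sanitize stacks the per-expert tensors along axis 0 into the same
# (num_experts, out_features, in_features) shape the combined-format
# branch produces for the switch_mlp.* keys.
stacked = mx.stack(
    [weights[f"{prefix}.experts.{e}.gate_proj.weight"] for e in range(2)]
)
print(stacked.shape)  # (2, 6, 4)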

2 files changed

Lines changed: 219 additions & 0 deletions

mlx_lm/models/qwen3_5_moe.py

Lines changed: 41 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 from dataclasses import dataclass
 
+import mlx.core as mx
+
 from .base import BaseModelArgs
 from .qwen3_5 import Model as Qwen3_5Model
 
@@ -48,5 +50,44 @@ def sanitize(self, weights):
                 new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop(
                     f"{prefix}.experts.down_proj"
                 )
+            elif f"{prefix}.experts.0.gate_proj.weight" in new_weights:
+                # Per-expert layout (Qwen/Qwen3.6-35B-A3B-FP8): one tensor per
+                # expert per projection. The bf16 master Qwen/Qwen3.6-35B-A3B
+                # is already pre-stacked and falls through the combined-format
+                # branch above unchanged. Collect all matching keys for this
+                # layer, validate the index range is contiguous from 0, then
+                # stack along axis 0 into the same shape the combined-format
+                # branch produces.
+                experts_prefix = f"{prefix}.experts."
+                gate_suffix = ".gate_proj.weight"
+                indices = set()
+                for key in new_weights:
+                    if key.startswith(experts_prefix) and key.endswith(gate_suffix):
+                        tail = key[len(experts_prefix) : -len(gate_suffix)]
+                        if tail.isdigit():
+                            indices.add(int(tail))
+                expected = set(range(len(indices)))
+                if indices != expected:
+                    missing = sorted(expected - indices)
+                    extra = sorted(indices - expected)
+                    raise ValueError(
+                        f"Per-expert MoE weights at {prefix}.experts have "
+                        f"non-contiguous indices: missing={missing}, "
+                        f"unexpected={extra}."
+                    )
+                gates, ups, downs = [], [], []
+                for e in range(len(indices)):
+                    gates.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.gate_proj.weight")
+                    )
+                    ups.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.up_proj.weight")
+                    )
+                    downs.append(
+                        new_weights.pop(f"{prefix}.experts.{e}.down_proj.weight")
+                    )
+                new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = mx.stack(gates)
+                new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = mx.stack(ups)
+                new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = mx.stack(downs)
 
         return self.language_model.sanitize(new_weights)
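With the per-expert branch in place, a strict load of the FP8 release should no longer report unexpected keys. A minimal smoke test, assuming mlx_lm's top-level load/generate API and that the FP8 checkpoint is otherwise supported by the installed version:

from mlx_lm import load, generate

# Illustrative only: model id taken from the commit message above.
model, tokenizer = load("Qwen/Qwen3.6-35B-A3B-FP8")
print(generate(model, tokenizer, prompt="Hello", max_tokens=8))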

tests/test_models.py

Lines changed: 178 additions & 0 deletions
@@ -649,6 +649,184 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
             mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
         )
 
+    def test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp(self):
+        # Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs as one tensor per expert
+        # per projection rather than as a single combined
+        # ``experts.gate_up_proj`` / ``experts.down_proj`` tensor (the bf16
+        # master Qwen/Qwen3.6-35B-A3B is already pre-stacked). The sanitize
+        # step must stack the per-expert tensors into the combined
+        # ``switch_mlp.*`` form that ``Qwen3_5MoeSparseMlp.load_weights``
+        # expects.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        per_expert = {
+            f"{prefix}.0.gate_proj.weight": mx.full((6, 4), 1.0),
+            f"{prefix}.0.up_proj.weight": mx.full((6, 4), 2.0),
+            f"{prefix}.0.down_proj.weight": mx.full((4, 6), 3.0),
+            f"{prefix}.1.gate_proj.weight": mx.full((6, 4), 4.0),
+            f"{prefix}.1.up_proj.weight": mx.full((6, 4), 5.0),
+            f"{prefix}.1.down_proj.weight": mx.full((4, 6), 6.0),
+        }
+        sanitized = model.sanitize(dict(per_expert))
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        gate_key = f"{out_prefix}.gate_proj.weight"
+        up_key = f"{out_prefix}.up_proj.weight"
+        down_key = f"{out_prefix}.down_proj.weight"
+
+        self.assertIn(gate_key, sanitized)
+        self.assertIn(up_key, sanitized)
+        self.assertIn(down_key, sanitized)
+        self.assertEqual(sanitized[gate_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[up_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[down_key].shape, (2, 4, 6))
+        # Per-expert keys must not leak through.
+        self.assertFalse(any(".experts.0." in k for k in sanitized))
+        self.assertFalse(any(".experts.1." in k for k in sanitized))
+        # Stacking preserves per-expert content along axis 0.
+        self.assertTrue(mx.array_equal(sanitized[gate_key][0], per_expert[f"{prefix}.0.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[gate_key][1], per_expert[f"{prefix}.1.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[down_key][1], per_expert[f"{prefix}.1.down_proj.weight"]))
+
+    def test_qwen3_5_moe_per_expert_gap_raises(self):
+        # Defensive: if a per-expert checkpoint has non-contiguous expert indices
+        # (e.g. ships 0, 1, 3 but not 2), sanitize must fail loudly rather than
+        # silently dropping expert 3 — which would replicate the strict-load
+        # failure this branch is meant to fix.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 3,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        # Index 2 deliberately missing.
+        gapped = {
+            f"{prefix}.0.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.1.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.3.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.down_proj.weight": mx.zeros((4, 6)),
+        }
+        with self.assertRaisesRegex(ValueError, "non-contiguous"):
+            model.sanitize(dict(gapped))
+
+    def test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp(self):
+        # Regression guard: pre-stacked checkpoints (e.g. mlx-community Qwen3.5
+        # / 3.6 redistributions) ship ``experts.gate_up_proj`` /
+        # ``experts.down_proj`` as combined tensors. Sanitize must still split
+        # them into ``switch_mlp.{gate,up,down}_proj.weight`` via the original
+        # branch, untouched by the new per-expert path.
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        # Pre-stacked: gate_up has shape (num_experts, 2*intermediate, hidden);
+        # down has shape (num_experts, hidden, intermediate).
+        gate_up = mx.arange(2 * 12 * 4, dtype=mx.float32).reshape(2, 12, 4)
+        down = mx.arange(2 * 4 * 6, dtype=mx.float32).reshape(2, 4, 6)
+        sanitized = model.sanitize(
+            {
+                "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up,
+                "model.language_model.layers.0.mlp.experts.down_proj": down,
+            }
+        )
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        self.assertEqual(sanitized[f"{out_prefix}.gate_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.up_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.down_proj.weight"].shape, (2, 4, 6))
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.gate_proj.weight"], gate_up[:, :6, :])
+        )
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.up_proj.weight"], gate_up[:, 6:, :])
+        )
+        self.assertTrue(mx.array_equal(sanitized[f"{out_prefix}.down_proj.weight"], down))
+        # Combined keys must not leak through after split.
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.gate_up_proj", sanitized)
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.down_proj", sanitized)
+
     def test_gemma4_convert_then_load_keeps_language_model_prefix(self):
         from mlx_lm.models import gemma4
 
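The three new tests can be run on their own with unittest's -k filter from the repository root, e.g. python -m unittest tests.test_models -k qwen3_5_moe (assuming the test file is importable as tests.test_models).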