@@ -649,6 +649,184 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
             mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
         )

+    def test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp(self):
+        # Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs as one tensor per expert
+        # per projection rather than as a single combined
+        # ``experts.gate_up_proj`` / ``experts.down_proj`` tensor (the bf16
+        # master Qwen/Qwen3.6-35B-A3B is already pre-stacked). The sanitize
+        # step must stack the per-expert tensors into the combined
+        # ``switch_mlp.*`` form that ``Qwen3_5MoeSparseMlp.load_weights``
+        # expects.
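+        # Key layout exercised below (per-expert input -> stacked output):
+        #   ...mlp.experts.<i>.{gate,up,down}_proj.weight : one 2-D tensor per expert
+        #   ...mlp.switch_mlp.{gate,up,down}_proj.weight  : 3-D, experts stacked on axis 0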
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        per_expert = {
+            f"{prefix}.0.gate_proj.weight": mx.full((6, 4), 1.0),
+            f"{prefix}.0.up_proj.weight": mx.full((6, 4), 2.0),
+            f"{prefix}.0.down_proj.weight": mx.full((4, 6), 3.0),
+            f"{prefix}.1.gate_proj.weight": mx.full((6, 4), 4.0),
+            f"{prefix}.1.up_proj.weight": mx.full((6, 4), 5.0),
+            f"{prefix}.1.down_proj.weight": mx.full((4, 6), 6.0),
+        }
+        sanitized = model.sanitize(dict(per_expert))
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        gate_key = f"{out_prefix}.gate_proj.weight"
+        up_key = f"{out_prefix}.up_proj.weight"
+        down_key = f"{out_prefix}.down_proj.weight"
+
+        self.assertIn(gate_key, sanitized)
+        self.assertIn(up_key, sanitized)
+        self.assertIn(down_key, sanitized)
+        self.assertEqual(sanitized[gate_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[up_key].shape, (2, 6, 4))
+        self.assertEqual(sanitized[down_key].shape, (2, 4, 6))
+        # Per-expert keys must not leak through.
+        self.assertFalse(any(".experts.0." in k for k in sanitized))
+        self.assertFalse(any(".experts.1." in k for k in sanitized))
+        # Stacking preserves per-expert content along axis 0.
+        self.assertTrue(mx.array_equal(sanitized[gate_key][0], per_expert[f"{prefix}.0.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[gate_key][1], per_expert[f"{prefix}.1.gate_proj.weight"]))
+        self.assertTrue(mx.array_equal(sanitized[down_key][1], per_expert[f"{prefix}.1.down_proj.weight"]))
+
+    def test_qwen3_5_moe_per_expert_gap_raises(self):
+        # Defensive: if a per-expert checkpoint has non-contiguous expert
+        # indices (e.g. ships 0, 1, 3 but not 2), sanitize must fail loudly
+        # rather than silently dropping expert 3, which would replicate the
+        # strict-load failure this branch is meant to fix.
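+        # num_experts is 3 in the config below; sanitize should detect that
+        # expert index 2 is absent and raise instead of stacking {0, 1, 3}.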
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 3,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        prefix = "model.language_model.layers.0.mlp.experts"
+        # Index 2 deliberately missing.
+        gapped = {
+            f"{prefix}.0.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.0.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.1.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.1.down_proj.weight": mx.zeros((4, 6)),
+            f"{prefix}.3.gate_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.up_proj.weight": mx.zeros((6, 4)),
+            f"{prefix}.3.down_proj.weight": mx.zeros((4, 6)),
+        }
+        with self.assertRaisesRegex(ValueError, "non-contiguous"):
+            model.sanitize(dict(gapped))
+
+    def test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp(self):
+        # Regression guard: pre-stacked checkpoints (e.g. mlx-community Qwen3.5
+        # / 3.6 redistributions) ship ``experts.gate_up_proj`` /
+        # ``experts.down_proj`` as combined tensors. Sanitize must still split
+        # them into ``switch_mlp.{gate,up,down}_proj.weight`` via the original
+        # branch, untouched by the new per-expert path.
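+        # In the combined layout, gate_up_proj carries the gate rows first and
+        # the up rows second along axis 1; down_proj is passed through as-is.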
+        from mlx_lm.models import qwen3_5_moe
+
+        text_config = {
+            "model_type": "qwen3_5_moe_text",
+            "hidden_size": 4,
+            "intermediate_size": 8,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 16,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 16,
+            "num_experts": 2,
+            "num_experts_per_tok": 1,
+            "moe_intermediate_size": 6,
+            "shared_expert_intermediate_size": 6,
+        }
+        args = qwen3_5_moe.ModelArgs.from_dict(
+            {"model_type": "qwen3_5_moe", "text_config": text_config}
+        )
+        model = qwen3_5_moe.Model(args)
+
+        # Pre-stacked: gate_up has shape (num_experts, 2*intermediate, hidden);
+        # down has shape (num_experts, hidden, intermediate).
+        gate_up = mx.arange(2 * 12 * 4, dtype=mx.float32).reshape(2, 12, 4)
+        down = mx.arange(2 * 4 * 6, dtype=mx.float32).reshape(2, 4, 6)
+        sanitized = model.sanitize(
+            {
+                "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up,
+                "model.language_model.layers.0.mlp.experts.down_proj": down,
+            }
+        )
+
+        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
+        self.assertEqual(sanitized[f"{out_prefix}.gate_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.up_proj.weight"].shape, (2, 6, 4))
+        self.assertEqual(sanitized[f"{out_prefix}.down_proj.weight"].shape, (2, 4, 6))
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.gate_proj.weight"], gate_up[:, :6, :])
+        )
+        self.assertTrue(
+            mx.array_equal(sanitized[f"{out_prefix}.up_proj.weight"], gate_up[:, 6:, :])
+        )
+        self.assertTrue(mx.array_equal(sanitized[f"{out_prefix}.down_proj.weight"], down))
+        # Combined keys must not leak through after split.
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.gate_up_proj", sanitized)
+        self.assertNotIn("language_model.model.layers.0.mlp.experts.down_proj", sanitized)
+
     def test_gemma4_convert_then_load_keeps_language_model_prefix(self):
         from mlx_lm.models import gemma4
