@@ -649,6 +649,182 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
            mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
        )

    def test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp(self):
        # Qwen3.6 MoE checkpoints (e.g. Qwen/Qwen3.6-35B-A3B-FP8) ship expert
        # MLPs as one tensor per expert per projection rather than as a single
        # combined ``experts.gate_up_proj`` / ``experts.down_proj`` tensor.
        # The sanitize step must stack them into the combined ``switch_mlp.*``
        # form that ``Qwen3_5MoeSparseMlp.load_weights`` expects.
        from mlx_lm.models import qwen3_5_moe

        text_config = {
            "model_type": "qwen3_5_moe_text",
            "hidden_size": 4,
            "intermediate_size": 8,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "rms_norm_eps": 1e-5,
            "vocab_size": 16,
            "linear_num_value_heads": 1,
            "linear_num_key_heads": 1,
            "linear_key_head_dim": 4,
            "linear_value_head_dim": 4,
            "linear_conv_kernel_dim": 1,
            "full_attention_interval": 1,
            "tie_word_embeddings": False,
            "max_position_embeddings": 16,
            "num_experts": 2,
            "num_experts_per_tok": 1,
            "moe_intermediate_size": 6,
            "shared_expert_intermediate_size": 6,
        }
        args = qwen3_5_moe.ModelArgs.from_dict(
            {"model_type": "qwen3_5_moe", "text_config": text_config}
        )
        model = qwen3_5_moe.Model(args)

        prefix = "model.language_model.layers.0.mlp.experts"
        per_expert = {
            f"{prefix}.0.gate_proj.weight": mx.full((6, 4), 1.0),
            f"{prefix}.0.up_proj.weight": mx.full((6, 4), 2.0),
            f"{prefix}.0.down_proj.weight": mx.full((4, 6), 3.0),
            f"{prefix}.1.gate_proj.weight": mx.full((6, 4), 4.0),
            f"{prefix}.1.up_proj.weight": mx.full((6, 4), 5.0),
            f"{prefix}.1.down_proj.weight": mx.full((4, 6), 6.0),
        }
        sanitized = model.sanitize(dict(per_expert))

        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
        gate_key = f"{out_prefix}.gate_proj.weight"
        up_key = f"{out_prefix}.up_proj.weight"
        down_key = f"{out_prefix}.down_proj.weight"

        self.assertIn(gate_key, sanitized)
        self.assertIn(up_key, sanitized)
        self.assertIn(down_key, sanitized)
        self.assertEqual(sanitized[gate_key].shape, (2, 6, 4))
        self.assertEqual(sanitized[up_key].shape, (2, 6, 4))
        self.assertEqual(sanitized[down_key].shape, (2, 4, 6))
        # Per-expert keys must not leak through.
        self.assertFalse(any(".experts.0." in k for k in sanitized))
        self.assertFalse(any(".experts.1." in k for k in sanitized))
        # Stacking preserves per-expert content along axis 0.
        self.assertTrue(
            mx.array_equal(
                sanitized[gate_key][0], per_expert[f"{prefix}.0.gate_proj.weight"]
            )
        )
        self.assertTrue(
            mx.array_equal(
                sanitized[gate_key][1], per_expert[f"{prefix}.1.gate_proj.weight"]
            )
        )
        self.assertTrue(
            mx.array_equal(
                sanitized[down_key][1], per_expert[f"{prefix}.1.down_proj.weight"]
            )
        )

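    # A minimal sketch of the stacking the test above pins down, assuming a
    # regex-grouping approach. `_stack_per_expert_sketch` is hypothetical,
    # written here for illustration rather than taken from the actual
    # `qwen3_5_moe.Model.sanitize` (which additionally remaps the
    # ``model.language_model`` prefix to ``language_model.model``).
    @staticmethod
    def _stack_per_expert_sketch(weights, num_experts):
        import re

        pat = re.compile(
            r"(?P<base>.*)\.experts\.(?P<idx>\d+)"
            r"\.(?P<proj>gate_proj|up_proj|down_proj)\.weight"
        )
        out, groups = {}, {}
        for key, value in weights.items():
            m = pat.fullmatch(key)
            if m is None:
                out[key] = value  # pass non-expert tensors through untouched
                continue
            groups.setdefault((m["base"], m["proj"]), {})[int(m["idx"])] = value
        for (base, proj), by_idx in groups.items():
            # Expert indices must be exactly 0..num_experts-1; see the gap
            # test below for why a hole must raise instead of stacking.
            if sorted(by_idx) != list(range(num_experts)):
                raise ValueError(f"non-contiguous expert indices for {base}.{proj}")
            out[f"{base}.switch_mlp.{proj}.weight"] = mx.stack(
                [by_idx[i] for i in range(num_experts)], axis=0
            )
        return out
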
    def test_qwen3_5_moe_per_expert_gap_raises(self):
        # Defensive: if a per-expert checkpoint has non-contiguous expert
        # indices (e.g. ships 0, 1, 3 but not 2), sanitize must fail loudly
        # rather than silently dropping expert 3, which would replicate the
        # strict-load failure this branch is meant to fix.
        from mlx_lm.models import qwen3_5_moe

        text_config = {
            "model_type": "qwen3_5_moe_text",
            "hidden_size": 4,
            "intermediate_size": 8,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "rms_norm_eps": 1e-5,
            "vocab_size": 16,
            "linear_num_value_heads": 1,
            "linear_num_key_heads": 1,
            "linear_key_head_dim": 4,
            "linear_value_head_dim": 4,
            "linear_conv_kernel_dim": 1,
            "full_attention_interval": 1,
            "tie_word_embeddings": False,
            "max_position_embeddings": 16,
            "num_experts": 3,
            "num_experts_per_tok": 1,
            "moe_intermediate_size": 6,
            "shared_expert_intermediate_size": 6,
        }
        args = qwen3_5_moe.ModelArgs.from_dict(
            {"model_type": "qwen3_5_moe", "text_config": text_config}
        )
        model = qwen3_5_moe.Model(args)

        prefix = "model.language_model.layers.0.mlp.experts"
        # Index 2 deliberately missing.
        gapped = {
            f"{prefix}.0.gate_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.0.up_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.0.down_proj.weight": mx.zeros((4, 6)),
            f"{prefix}.1.gate_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.1.up_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.1.down_proj.weight": mx.zeros((4, 6)),
            f"{prefix}.3.gate_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.3.up_proj.weight": mx.zeros((6, 4)),
            f"{prefix}.3.down_proj.weight": mx.zeros((4, 6)),
        }
        with self.assertRaisesRegex(ValueError, "non-contiguous"):
            model.sanitize(dict(gapped))

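    # The defensive check above in isolation (a hypothetical helper, not the
    # real sanitize code): the indices seen must be exactly 0..num_experts-1,
    # so {0, 1, 3} with num_experts == 3 raises before any stacking happens.
    @staticmethod
    def _check_contiguous_sketch(expert_indices, num_experts):
        if sorted(expert_indices) != list(range(num_experts)):
            raise ValueError(
                f"non-contiguous expert indices: {sorted(expert_indices)}"
            )
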
    def test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp(self):
        # Regression guard: pre-stacked checkpoints (e.g. mlx-community
        # Qwen3.5 / 3.6 redistributions) ship ``experts.gate_up_proj`` /
        # ``experts.down_proj`` as combined tensors. Sanitize must still
        # split them into ``switch_mlp.{gate,up,down}_proj.weight`` via the
        # original branch, untouched by the new per-expert path.
        from mlx_lm.models import qwen3_5_moe

        text_config = {
            "model_type": "qwen3_5_moe_text",
            "hidden_size": 4,
            "intermediate_size": 8,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "rms_norm_eps": 1e-5,
            "vocab_size": 16,
            "linear_num_value_heads": 1,
            "linear_num_key_heads": 1,
            "linear_key_head_dim": 4,
            "linear_value_head_dim": 4,
            "linear_conv_kernel_dim": 1,
            "full_attention_interval": 1,
            "tie_word_embeddings": False,
            "max_position_embeddings": 16,
            "num_experts": 2,
            "num_experts_per_tok": 1,
            "moe_intermediate_size": 6,
            "shared_expert_intermediate_size": 6,
        }
        args = qwen3_5_moe.ModelArgs.from_dict(
            {"model_type": "qwen3_5_moe", "text_config": text_config}
        )
        model = qwen3_5_moe.Model(args)

        # Pre-stacked: gate_up has shape (num_experts, 2 * moe_intermediate,
        # hidden); down has shape (num_experts, hidden, moe_intermediate).
        gate_up = mx.arange(2 * 12 * 4, dtype=mx.float32).reshape(2, 12, 4)
        down = mx.arange(2 * 4 * 6, dtype=mx.float32).reshape(2, 4, 6)
        sanitized = model.sanitize(
            {
                "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up,
                "model.language_model.layers.0.mlp.experts.down_proj": down,
            }
        )

        out_prefix = "language_model.model.layers.0.mlp.switch_mlp"
        self.assertEqual(sanitized[f"{out_prefix}.gate_proj.weight"].shape, (2, 6, 4))
        self.assertEqual(sanitized[f"{out_prefix}.up_proj.weight"].shape, (2, 6, 4))
        self.assertEqual(sanitized[f"{out_prefix}.down_proj.weight"].shape, (2, 4, 6))
        self.assertTrue(
            mx.array_equal(
                sanitized[f"{out_prefix}.gate_proj.weight"], gate_up[:, :6, :]
            )
        )
        self.assertTrue(
            mx.array_equal(
                sanitized[f"{out_prefix}.up_proj.weight"], gate_up[:, 6:, :]
            )
        )
        self.assertTrue(
            mx.array_equal(sanitized[f"{out_prefix}.down_proj.weight"], down)
        )
        # Combined keys must not leak through after split.
        self.assertNotIn(
            "language_model.model.layers.0.mlp.experts.gate_up_proj", sanitized
        )
        self.assertNotIn(
            "language_model.model.layers.0.mlp.experts.down_proj", sanitized
        )

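    # A minimal sketch of the combined-format split the test above guards,
    # assuming the pre-stacked layout documented in that test; a hypothetical
    # helper, not the actual mlx_lm branch.
    @staticmethod
    def _split_gate_up_sketch(gate_up, down, out_prefix):
        # gate_up: (num_experts, 2 * moe_intermediate, hidden). The first
        # half of axis 1 is the gate projection, the second half the up
        # projection; down is already in its final (E, hidden, inter) form.
        half = gate_up.shape[1] // 2
        return {
            f"{out_prefix}.gate_proj.weight": gate_up[:, :half, :],
            f"{out_prefix}.up_proj.weight": gate_up[:, half:, :],
            f"{out_prefix}.down_proj.weight": down,
        }
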
    def test_gemma4_convert_then_load_keeps_language_model_prefix(self):
        from mlx_lm.models import gemma4
