
Support per-expert MoE checkpoints in qwen3_5_moe.sanitize, plus FP8 dequant#1224

Open
sdayal wants to merge 1 commit into ml-explore:main from sdayal:qwen3_5_moe-per-expert-sanitize

Conversation


@sdayal sdayal commented Apr 28, 2026

Summary

Qwen-org's canonical FP8 release Qwen/Qwen3.6-35B-A3B-FP8 fails strict load_weights on mlx-lm main with `Received 61690 parameters not in model`. This PR fixes both root causes and verifies the model loads end-to-end on vanilla mlx-lm with no external compat layer.

What's broken on main

The 61,690-key failure has two independent contributors:

| Category | Count | Cause |
| --- | --- | --- |
| Per-expert weight tensors (`...experts.{E}.{gate,up,down}_proj.weight`) | 30,720 | `qwen3_5_moe.Model.sanitize` only handles the combined-format layout (`experts.gate_up_proj` / `experts.down_proj`) produced by `mlx_lm.convert`. |
| Per-expert FP8 scale tensors (`...weight_scale_inv`) | 30,720 | No `weight_scale_inv` dequant exists in `qwen3_5{,_moe}.py`, even though five sibling files (`deepseek_v3.py`, `deepseek_v32.py`, `minimax.py`, `mimo_v2_flash.py`, `ministral3.py`) handle the same `quant_method="fp8"` block layout inline. |
| Non-expert FP8 scale tensors (attention, lm_head, etc.) | 250 | Same as above. |

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked and has no FP8 scales; it loads on main unchanged. Only the FP8 release exhibits both bugs.

Fixes

FP8 weight_scale_inv dequant + activation_scale drop in qwen3_5.Model.sanitize and qwen3_5_moe.Model.sanitize. Mirrors the inline pattern from deepseek_v3.py::sanitize::dequant: 128×128 block, mx.from_fp8(..., dtype=mx.bfloat16) followed by per-block scale broadcast and pad/unpad. No-op when no weight_scale_inv keys are present.
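For readers unfamiliar with the block-scale layout, here is a minimal sketch of that dequant step. It assumes the `mx.from_fp8` call named above; the padding/cropping for dimensions that are not multiples of 128 is illustrative rather than a copy of the code in this PR.

```python
import mlx.core as mx

BLOCK = 128  # Qwen's FP8 release uses 128x128 block scales


def dequant_fp8_block(weight_fp8, scale_inv):
    """Sketch: expand per-block scales and multiply back into bf16 weights.

    weight_fp8: E4M3 tensor of shape (out_dim, in_dim)
    scale_inv:  per-block scales of shape (ceil(out/128), ceil(in/128))
    """
    w = mx.from_fp8(weight_fp8, dtype=mx.bfloat16)  # call name taken from the PR text
    out_dim, in_dim = w.shape
    # Broadcast each block scale over its 128x128 tile, then crop the padded
    # tail for dims that are not multiples of 128.
    scales = mx.repeat(mx.repeat(scale_inv, BLOCK, axis=0), BLOCK, axis=1)
    return (w * scales[:out_dim, :in_dim]).astype(mx.bfloat16)
```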

Per-expert MoE stacking branch in qwen3_5_moe.Model.sanitize, using a (scan → validate → walk) structure:

  1. Scan the weights dict for per-layer experts prefixes and their expert-index sets.
  2. Validate each prefix's index set is a contiguous {0, 1, …, N-1} (raises ValueError otherwise).
  3. Walk per-expert tensors in order, mx.stack along axis 0, and emit the combined switch_mlp.{gate,up,down}_proj.weight form downstream load_weights expects.

Pre-stacked checkpoints take the original `if gate_up_key in new_weights:` branch unchanged.

A defensive contiguity check raises ValueError on non-contiguous expert indices (e.g. {0, 1, 3} skipping 2), so a malformed checkpoint fails loudly rather than silently dropping experts.
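A minimal sketch of the scan → validate → walk structure, written against a plain key → tensor dict. Key names follow the per-expert layout quoted above; the helper name, regex, and the `experts` → `switch_mlp` rename are illustrative, not the identifiers in the actual patch.

```python
import re
import mlx.core as mx

# Hypothetical helper illustrating the stacking branch; the real logic lives
# inline in qwen3_5_moe.Model.sanitize.
EXPERT_RE = re.compile(
    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<idx>\d+)\.(?P<proj>gate|up|down)_proj\.weight$"
)


def stack_per_expert(weights):
    # Scan: collect per-expert tensors grouped by (layer prefix, projection).
    per_expert = {}
    for key in list(weights):
        m = EXPERT_RE.match(key)
        if m:
            per_expert.setdefault((m["prefix"], m["proj"]), {})[int(m["idx"])] = weights.pop(key)

    for (prefix, proj), by_idx in per_expert.items():
        # Validate: indices must be exactly {0, 1, ..., N-1}.
        if sorted(by_idx) != list(range(len(by_idx))):
            raise ValueError(f"non-contiguous expert indices for {prefix}.*.{proj}_proj")
        # Walk in order and stack along a new leading expert axis.
        stacked = mx.stack([by_idx[i] for i in range(len(by_idx))], axis=0)
        # Emit the combined form downstream load_weights expects (rename assumed).
        out_prefix = prefix.replace(".experts", ".switch_mlp")
        weights[f"{out_prefix}.{proj}_proj.weight"] = stacked
    return weights
```

The contiguity check in the middle is the piece the gap-raises unit test exercises.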

Tests

Five unit tests in tests/test_models.py:

  • test_qwen3_5_fp8_weight_scale_inv_dequantizes_in_sanitize — FP8 dequant via dense Model.sanitize, activation_scale drop verified.
  • test_qwen3_5_moe_fp8_weight_scale_inv_dequantizes_in_sanitize — same, via MoE Model.sanitize.
  • test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp — positive: per-expert input produces correctly stacked switch_mlp.* output.
  • test_qwen3_5_moe_per_expert_gap_raises — defensive: {0, 1, 3} raises ValueError with "non-contiguous" in the message.
  • test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp — regression guard for the original combined-format branch.

All five pass.
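Because the stacking operates on a plain key → tensor dict, the defensive case is cheap to exercise with tiny random tensors. A sketch of the gap test against the illustrative `stack_per_expert` helper above (the real tests call `Model.sanitize` on a small config):

```python
import mlx.core as mx
import pytest


def test_gap_raises_sketch():
    # Experts 0, 1, 3 present, 2 missing: sanitize should fail loudly.
    weights = {
        f"model.layers.0.mlp.experts.{e}.gate_proj.weight": mx.zeros((4, 8))
        for e in (0, 1, 3)
    }
    with pytest.raises(ValueError, match="non-contiguous"):
        stack_per_expert(weights)
```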

End-to-end verification

Loaded Qwen/Qwen3.6-35B-A3B-FP8 via vanilla mlx_lm.load(...) on Apple Silicon Metal (M3 Max / 128 GB), no downstream compat layer:

  • Before: `ValueError: Received 61690 parameters not in model`
  • After: the load succeeds; the dequantized weight tensor at `language_model.model.layers.0.linear_attn.in_proj_qkv.weight` has shape (8192, 2048), dtype bfloat16, all finite and nonzero, with norm 62.25, consistent with a real attention-projection weight matrix.
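A sketch of how that spot-check can be reproduced; `mlx_lm.load`, `tree_flatten`, and the `mx` calls are real APIs, while the parameter key is copied from the bullet above and may need adapting to the loaded model's parameter tree.

```python
import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen3.6-35B-A3B-FP8")
params = dict(tree_flatten(model.parameters()))
w = params["language_model.model.layers.0.linear_attn.in_proj_qkv.weight"]  # key from the PR text
print(w.shape, w.dtype)                                   # expect (8192, 2048), bfloat16
print(mx.all(mx.isfinite(w.astype(mx.float32))).item())   # expect True
print(mx.linalg.norm(w.astype(mx.float32)).item())        # expect a sane, nonzero norm
```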

Pre-stacked checkpoints (mlx-community redistributions, the Qwen/Qwen3.6-35B-A3B bf16 master, unsloth MLX variants) load unchanged via the combined-format branch, with both new sanitize paths dormant, confirming backward compatibility.

Affected / not affected

| Source | Layout | Loads on main today | Loads with this PR |
| --- | --- | --- | --- |
| Qwen/Qwen3.6-35B-A3B-FP8 | per-expert + FP8 (E4M3, 128×128 block) | ❌ 61,690 unrecognized keys | ✅ |
| Qwen/Qwen3.6-35B-A3B (bf16 master) | combined-format, pre-stacked | ✅ | ✅ unchanged |
| mlx-community/Qwen3.6-35B-A3B-{bf16,4bit,8bit,mxfp8,nvfp4,…} | post-sanitize MLX-native (switch_mlp.*) | ✅ | ✅ unchanged |
| unsloth/Qwen3.6-35B-A3B-{UD-MLX-4bit,MLX-8bit,UD-MLX-3bit} | post-sanitize MLX-native | ✅ | ✅ unchanged |

Note on the inline FP8 dequant duplication

The FP8 dequant code added here is a near-byte-identical copy of the same block in models/{deepseek_v3,deepseek_v32,minimax,mimo_v2_flash,ministral3}.py. A future PR could extract a shared models/_fp8.py helper and consolidate across all six FP8-handling model files; happy to do that as a follow-up if preferred, but kept this PR scoped to "make Qwen/Qwen3.6-35B-A3B-FP8 load" for review tractability.

sdayal added a commit to sdayal/vllm-metal that referenced this pull request Apr 28, 2026
Qwen-org Qwen3.6 MoE checkpoints (e.g. Qwen/Qwen3.6-35B-A3B-FP8) ship
expert MLPs as one tensor per expert per projection:
  model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

mlx_lm.qwen3_5_moe.Model.sanitize expects the combined-format layout
(experts.gate_up_proj / experts.down_proj) produced by mlx_lm.convert and
shipped by mlx-community redistributions. Loading a Qwen-org checkpoint
fails strict load_weights with thousands of unexpected keys (30720 for the
35B-A3B variant: 256 experts x 40 layers x 3 projections).

Extend the existing FP8 sanitize compat shim with a pre-step that detects
per-expert MoE tensors, validates the index range is contiguous from 0,
and stacks them (mx.stack along axis 0) into the combined experts.gate_up_proj
+ experts.down_proj form upstream sanitize already handles. Pre-stacked
checkpoints are unaffected (helper short-circuits when no per-expert keys
are present).

This is the downstream complement to ml-explore/mlx-lm#1224, which adds
the same stacking logic inline in qwen3_5_moe.Model.sanitize. When that
lands and vllm-metal's mlx-lm pin bumps past it, this shim can be removed.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights helper,
  chained after FP8 dequant in the patched sanitize.
- docs/supported_models.md: update Qwen3.6 row note.
- tests/test_qwen36_smoke.py: opt-in e2e smoke gated on
  QWEN36_MOE_FP8_PATH env var (pytest skips by default; runs in 25s
  against a local Qwen3.6 MoE FP8 checkpoint).

Verified end-to-end on Qwen/Qwen3.6-35B-A3B-FP8: greedy decode of
"The capital of France is" returns " Paris, a city renowned for its iconic
landmarks such" with the hybrid SDPA + GDN linear attention path on Apple
Silicon Metal. Existing Qwen3.5 golden-token smoke (test_qwen35_smoke.py)
unchanged: 5/5 pass.
@sdayal sdayal force-pushed the qwen3_5_moe-per-expert-sanitize branch 2 times, most recently from b9b247d to b01aef2 Compare April 28, 2026 17:43
sdayal added a commit to sdayal/vllm-metal that referenced this pull request Apr 28, 2026
Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs as one tensor per expert per
projection:
  model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked
(experts.gate_up_proj / experts.down_proj) and loads via the existing
combined-format branch in mlx_lm.qwen3_5_moe.Model.sanitize unchanged;
only the FP8 release lands per-expert, likely because Qwen's FP8
quantization pipeline runs per-expert and the artifact is not re-stacked.

Loading Qwen/Qwen3.6-35B-A3B-FP8 today fails strict load_weights with
thousands of unexpected keys (30720 for 35B-A3B: 256 experts x 40 layers
x 3 projections).

Extend the existing FP8 sanitize compat shim with a pre-step that detects
per-expert MoE tensors, validates the index range is contiguous from 0,
and stacks them (mx.stack along axis 0) into the combined
experts.gate_up_proj + experts.down_proj form upstream sanitize already
handles. Pre-stacked checkpoints are unaffected (helper short-circuits
when no per-expert keys are present).

This is the downstream complement to ml-explore/mlx-lm#1224, which adds
the same stacking logic inline in qwen3_5_moe.Model.sanitize. When that
lands and vllm-metal's mlx-lm pin bumps past it, this shim can be removed.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights helper,
  chained after FP8 dequant in the patched sanitize.
- docs/supported_models.md: update Qwen3.6 row note.
- tests/test_qwen36_smoke.py: opt-in e2e smoke gated on
  QWEN36_MOE_FP8_PATH env var (pytest skips by default; runs in 25s
  against a local Qwen3.6 MoE FP8 checkpoint).

Verified end-to-end on Qwen/Qwen3.6-35B-A3B-FP8: greedy decode of
"The capital of France is" returns " Paris, a city renowned for its iconic
landmarks such" with the hybrid SDPA + GDN linear attention path on Apple
Silicon Metal. Existing Qwen3.5 golden-token smoke (test_qwen35_smoke.py)
unchanged: 5/5 pass.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
sdayal added a commit to sdayal/vllm-metal that referenced this pull request Apr 28, 2026
Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs per-expert
(model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight),
which mlx_lm.qwen3_5_moe.sanitize doesn't recognize. Loading on unpatched
mlx-lm fails strict load_weights with 30720 unexpected keys (256 experts
x 40 layers x 3 projections).

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked
(experts.gate_up_proj / experts.down_proj) and loads via the existing
combined-format branch unchanged; only the FP8 release lands per-expert,
likely because Qwen's FP8 quantization pipeline runs per-expert.

Extend the FP8 sanitize compat shim with a pre-step that collects
per-expert tensors, validates contiguous indices from 0, and stacks them
into the combined form upstream sanitize already handles. Helper
short-circuits when no per-expert keys are present.

mlx-community publishes pre-stacked redistributions, but converting
Qwen-org's FP8 release ourselves needs a 35GB->70GB bf16 intermediate
that doesn't fit on Macs <=64 GB. This shim lets users load the
canonical artifact directly. Downstream complement to
ml-explore/mlx-lm#1224; removable once vllm-metal's mlx-lm pin bumps
past that merge.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights, chained
  after FP8 dequant in the patched sanitize.
- docs/supported_models.md: update Qwen3.6 row note.
- tests/test_qwen36_smoke.py: opt-in e2e smoke gated on QWEN36_MOE_PATH
  (skipped by default; works against either FP8 or bf16 checkpoint).

Verified: Qwen/Qwen3.6-35B-A3B-FP8 generates correctly via the new
branch, Qwen/Qwen3.6-35B-A3B (bf16) via the unchanged combined-format
branch (both shims dormant), Qwen3.5 golden-token smoke 5/5 pass.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
@sdayal sdayal force-pushed the qwen3_5_moe-per-expert-sanitize branch from b01aef2 to cd039ff Compare April 28, 2026 18:10
sdayal added a commit to sdayal/vllm-metal that referenced this pull request Apr 29, 2026
Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs per-expert
(model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight),
which mlx_lm.qwen3_5_moe.sanitize doesn't recognize. Loading on unpatched
mlx-lm fails strict load_weights with 30720 unexpected keys (256 experts
x 40 layers x 3 projections).

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked
(experts.gate_up_proj / experts.down_proj) and loads via the existing
combined-format branch unchanged; only the FP8 release lands per-expert,
likely because Qwen's FP8 quantization pipeline runs per-expert.

Add _stack_qwen36_moe_per_expert_weights mirroring the (scan -> validate
-> walk) structure of ml-explore/mlx-lm#1224. Split the sanitize patch
into per-class transforms so the MoE-only nature of the stacking is
self-evident:

- mlx_lm.models.qwen3_5.Model      -> _transform_dense (FP8 dequant only)
- mlx_lm.models.qwen3_5_moe.Model  -> _transform_moe   (FP8 dequant + stack)

mlx-community publishes pre-stacked redistributions, but converting
Qwen-org's FP8 release ourselves needs a 35GB->70GB bf16 intermediate
that doesn't fit on Macs <=64 GB. This shim lets users load the
canonical artifact directly. Removable once vllm-metal's mlx-lm pin
bumps past mlx-lm#1224.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights helper,
  split sanitize patches by model class via transforms_by_module map.
- docs/supported_models.md: update Qwen3.6 row note + link this PR.
- tests/test_compat.py: 4 new unit tests covering positive (per-expert
  -> combined), regression (pre-stacked no-op), defensive (gap raises),
  and architecture invariant (dense path doesn't run MoE helper).

Verified end-to-end on Apple Silicon Metal: Qwen/Qwen3.6-35B-A3B-FP8
generates correctly via the new branch, Qwen/Qwen3.6-35B-A3B (bf16) via
the unchanged combined-format branch (both shims dormant - confirms the
new branch is properly gated). Existing Qwen3.5 golden-token smoke
(test_qwen35_smoke.py) unchanged: 5/5 pass.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
Qwen-org's canonical FP8 release Qwen/Qwen3.6-35B-A3B-FP8 fails strict
load_weights on mlx-lm main with `Received 61690 parameters not in
model`, broken down as two independent issues:

1. Per-expert MoE layout. The FP8 release ships expert MLPs as one
   tensor per expert per projection
   (model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj
   .weight), 30720 = 256 experts x 40 layers x 3 projections. The bf16
   master Qwen/Qwen3.6-35B-A3B is already pre-stacked
   (experts.gate_up_proj / experts.down_proj) and loads via the
   existing combined-format branch unchanged; only the FP8 release
   lands per-expert, likely because Qwen's FP8 quantization pipeline
   runs per-expert.

2. FP8 weight_scale_inv. The release uses Qwen's standard
   quant_method="fp8" with E4M3 weights + bf16 128x128 block scales
   (*_weight_scale_inv tensors) plus optional activation_scale
   tensors. The qwen3_5{,_moe} family had no path for these. The same
   layout is already handled inline in
   models/{deepseek_v3,deepseek_v32,minimax,mimo_v2_flash,ministral3}
   .py (30970 keys: 30720 per-expert FP8 scales + 250 attention/
   lm_head scales).

Together the two bugs account for the 61690-key strict-load failure
on vanilla main.

This PR adds:

- FP8 weight_scale_inv dequant + activation_scale drop in
  qwen3_5.Model.sanitize and qwen3_5_moe.Model.sanitize, mirroring
  the inline pattern in the five sibling FP8-handling model files.
- Per-expert MoE stacking branch in qwen3_5_moe.Model.sanitize, using
  a (scan -> validate -> walk) structure that emits the combined
  switch_mlp.{gate,up,down}_proj.weight form downstream load_weights
  expects. Pre-stacked checkpoints take the original branch
  unchanged.
- Defensive contiguity check raises ValueError on non-contiguous
  expert indices ({0,1,3} skipping 2) so a malformed checkpoint
  fails loud rather than silently dropping experts.

Tests in tests/test_models.py:

- test_qwen3_5_fp8_weight_scale_inv_dequantizes_in_sanitize
  (FP8 dequant via dense Model.sanitize, activation_scale drop)
- test_qwen3_5_moe_fp8_weight_scale_inv_dequantizes_in_sanitize
  (FP8 dequant via MoE Model.sanitize)
- test_qwen3_5_moe_per_expert_weights_stack_to_switch_mlp
  (positive: per-expert input -> stacked switch_mlp.* output)
- test_qwen3_5_moe_per_expert_gap_raises
  (defensive: {0,1,3} raises ValueError)
- test_qwen3_5_moe_combined_format_still_splits_to_switch_mlp
  (regression guard for the original combined-format branch)

Verified end-to-end on Apple Silicon Metal: Qwen/Qwen3.6-35B-A3B-FP8
loads via vanilla mlx-lm (no downstream compat layer) and produces
finite, sane bf16 weights after dequant. Pre-stacked checkpoints
(mlx-community redistributions, Qwen/Qwen3.6-35B-A3B bf16 master,
unsloth MLX variants) load unchanged via the existing combined-format
branch with both shims dormant.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
@sdayal sdayal force-pushed the qwen3_5_moe-per-expert-sanitize branch from cd039ff to eb76d4f Compare April 29, 2026 20:54
@sdayal sdayal changed the title Support per-expert MoE checkpoints in qwen3_5_moe.sanitize Support per-expert MoE checkpoints in qwen3_5_moe.sanitize, plus FP8 dequant Apr 29, 2026
@sdayal sdayal marked this pull request as draft April 30, 2026 05:49
LxYuan0420 pushed a commit to vllm-project/vllm-metal that referenced this pull request Apr 30, 2026
Tracking issue: #289

## Summary

`Qwen/Qwen3.6-35B-A3B-FP8` ships expert MLPs as one tensor per expert
per projection:

```
model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight
```

The bf16 master `Qwen/Qwen3.6-35B-A3B` is already pre-stacked
(`experts.gate_up_proj` / `experts.down_proj`) and loads via the
existing combined-format branch in `mlx_lm.qwen3_5_moe.Model.sanitize`
unchanged — only the FP8 release lands per-expert, likely because Qwen's
FP8 quantization pipeline runs per-expert and the artifact is not
re-stacked.

On vllm-metal `main`, loading `Qwen/Qwen3.6-35B-A3B-FP8` fails strict
`load_weights` with `Received 30720 parameters not in model` — these are
per-expert MoE tensors that vllm-metal's existing FP8 dequant compat
doesn't address. (For reference: the same checkpoint on vanilla mlx-lm
fails with 61,690 keys, the difference being 30,970 FP8
`weight_scale_inv` tensors that vllm-metal's
`compat.py::_dequantize_qwen35_fp8_weights` already handles separately.)

## What this PR does

Add `_stack_qwen36_moe_per_expert_weights` chained after FP8 dequant in
the MoE sanitize wrapper:

1. **Scan** the weights dict for per-layer experts prefixes and their
expert-index sets.
2. **Validate** each prefix's index set is a contiguous `{0, 1, …, N-1}`
(raises `ValueError` otherwise).
3. **Walk** the per-expert tensors in order, `mx.stack` along axis 0,
`mx.concatenate` gate+up along the intermediate-dim axis, emit the
combined `experts.gate_up_proj` / `experts.down_proj` form upstream
sanitize already handles.

Pre-stacked checkpoints are unaffected (helper short-circuits when no
per-expert keys are present).

The MoE-only nature of the stacking is made explicit by splitting the
sanitize patch by model class:

- `mlx_lm.models.qwen3_5.Model` → wrapped with **FP8 dequant only**
(`_transform_dense`)
- `mlx_lm.models.qwen3_5_moe.Model` → wrapped with **FP8 dequant +
per-expert stacking** (`_transform_moe`)

Routing is driven by an explicit `transforms_by_module` map; future Qwen
variants added without a corresponding entry get logged as `unpatchable`
rather than silently inheriting one of the two transforms.
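A rough sketch of that per-class routing, using the helper names from the Files list below; the stub bodies, wrapper signature, and logging are illustrative, not the actual `compat.py` implementation.

```python
import logging

logger = logging.getLogger(__name__)


def _dequantize_qwen35_fp8_weights(weights):
    ...  # existing FP8 dequant shim (see compat.py)
    return weights


def _stack_qwen36_moe_per_expert_weights(weights):
    ...  # per-expert stacking helper added by this PR
    return weights


# Map each mlx-lm model module to the ordered transform chain its sanitize
# gets wrapped with (helper names from this PR; wiring is illustrative).
transforms_by_module = {
    "mlx_lm.models.qwen3_5": [_dequantize_qwen35_fp8_weights],  # dense: FP8 dequant only
    "mlx_lm.models.qwen3_5_moe": [
        _dequantize_qwen35_fp8_weights,
        _stack_qwen36_moe_per_expert_weights,                   # MoE: dequant + stack
    ],
}


def patch_sanitize(model_cls, module_name):
    transforms = transforms_by_module.get(module_name)
    if transforms is None:
        # New Qwen variants without an entry are reported, not silently patched.
        logger.warning("unpatchable Qwen variant: %s", module_name)
        return
    original_sanitize = model_cls.sanitize

    def sanitize(self, weights):
        for transform in transforms:
            weights = transform(weights)
        return original_sanitize(self, weights)

    model_cls.sanitize = sanitize
```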

## Why a vllm-metal compat shim instead of waiting for upstream

mlx-community publishes pre-stacked redistributions of this checkpoint
that already load on existing mlx-lm. This shim lets users load
Qwen-org's canonical FP8 artifact directly without a 35GB→70GB bf16
intermediate conversion step that doesn't fit on memory-constrained Macs
(≤64 GB).

This complements ml-explore/mlx-lm#1224, which adds the same per-expert
stacking logic plus FP8 `weight_scale_inv` dequant for the qwen3_5
family inline in upstream sanitize. Once mlx-lm#1224 lands and
vllm-metal's mlx-lm pin bumps past a release containing it, both this
PR's per-expert stacking shim and the existing
`_dequantize_qwen35_fp8_weights` shim in `compat.py` become removable in
a follow-up cleanup.

## Files

- `vllm_metal/compat.py` — add `_stack_qwen36_moe_per_expert_weights`
helper; split sanitize patches into per-class transforms via
`transforms_by_module`.
- `docs/supported_models.md` — update Qwen3.6 row note + link this PR.
- `tests/test_compat.py` — four new unit tests using the existing
numpy-fake-mlx fixture (no real model weights, runs in milliseconds):
- `test_per_expert_moe_tensors_stack_to_combined` — positive: per-expert
input produces correctly stacked combined output, content preserved per
axis-0 slot.
- `test_pre_stacked_moe_is_noop_for_per_expert_helper` — regression:
pre-stacked input passes through unchanged (covers mlx-community
redistributions and Qwen3.6 bf16 master).
- `test_non_contiguous_per_expert_indices_raise` — defensive: malformed
`{0, 1, 3}` checkpoint raises `ValueError`.
- `test_per_expert_helper_does_not_run_on_dense_qwen35` — architecture
invariant: dense path doesn't run the MoE helper.

## Verification (per #289 pass bar)

| Checkpoint | Hardware | Status | Output |
|---|---|---|---|
| `Qwen/Qwen3.6-35B-A3B-FP8` | M3 Max / 128 GB | ✅ loads, generates | `"The capital of France is"` → `" Paris, a city renowned for its iconic landmarks such"` |
| `Qwen/Qwen3.6-35B-A3B` (bf16) | M3 Max / 128 GB | ✅ loads via unchanged combined-format branch (both shims dormant — confirms the new branch is properly gated) | same output |

- Hybrid SDPA + GDN linear attention path on Apple Silicon Metal, paged
KV cache.
- `pytest tests/test_compat.py`: **15 passed, 1 skipped** (the skip is
the pre-existing `VLLM_METAL_RUN_REAL_MLX_FP8_TESTS=1`-gated test).
- Existing Qwen3.5 golden-token smoke (`test_qwen35_smoke.py`): 5/5
pass, unchanged.
- `bash scripts/lint.sh`: clean (shellcheck, ruff check, ruff format
--check, mypy).

Rebased on latest `main` (3323d32) and re-validated against the bumped
dep stack: `mlx-lm 0.31.3` (from #313), `vllm 0.20.0+cpu` (from #262),
`transformers 5.7.0`.

---------

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
@sdayal sdayal marked this pull request as ready for review April 30, 2026 10:01