
[Qwen3.6] Stack per-expert MoE tensors during mlx_lm sanitize#312

Merged
LxYuan0420 merged 2 commits into vllm-project:main from sdayal:qwen3_6-moe-per-expert-stacking
Apr 30, 2026
Conversation

@sdayal
Contributor

@sdayal sdayal commented Apr 28, 2026

Tracking issue: #289

Summary

Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs as one tensor per expert per projection:

model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked (experts.gate_up_proj / experts.down_proj) and loads unchanged via the existing combined-format branch in mlx_lm.qwen3_5_moe.Model.sanitize. Only the FP8 release lands per-expert, likely because Qwen's FP8 quantization pipeline runs per-expert and the artifact is not re-stacked afterwards.
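For concreteness, the two layouts look roughly like this for layer 0 (illustrative only; key names follow the description above, and the pre-stacked names are written exactly as in this summary, without speculating on suffixes):

```python
# FP8 release: one tensor per expert per projection (expert 0 shown,
# repeated for experts 1..N-1 and for every layer).
per_expert_keys = [
    "model.language_model.layers.0.mlp.experts.0.gate_proj.weight",
    "model.language_model.layers.0.mlp.experts.0.up_proj.weight",
    "model.language_model.layers.0.mlp.experts.0.down_proj.weight",
]

# bf16 master: pre-stacked, two tensors per layer covering all experts.
pre_stacked_keys = [
    "model.language_model.layers.0.mlp.experts.gate_up_proj",
    "model.language_model.layers.0.mlp.experts.down_proj",
]
```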

On vllm-metal main, loading Qwen/Qwen3.6-35B-A3B-FP8 fails strict load_weights with "Received 30720 parameters not in model"; those 30,720 parameters are the per-expert MoE tensors that vllm-metal's existing FP8 dequant compat doesn't address. (For reference, the same checkpoint on vanilla mlx-lm fails with 61,690 unexpected keys; the difference is the 30,970 FP8 weight_scale_inv tensors that vllm-metal's compat.py::_dequantize_qwen35_fp8_weights already handles separately.)

What this PR does

Add _stack_qwen36_moe_per_expert_weights chained after FP8 dequant in the MoE sanitize wrapper:

  1. Scan the weights dict for per-layer experts prefixes and their expert-index sets.
  2. Validate that each prefix's index set is a contiguous {0, 1, …, N-1}, raising ValueError otherwise.
  3. Walk the per-expert tensors in order, mx.stack them along axis 0, mx.concatenate gate and up along the intermediate-dim axis, and emit the combined experts.gate_up_proj / experts.down_proj form that upstream sanitize already handles.

Pre-stacked checkpoints are unaffected (helper short-circuits when no per-expert keys are present).
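For readers who want the shape of the helper without opening compat.py, here is a minimal sketch of the scan → validate → walk structure. Everything except mx.stack, mx.concatenate, and the experts.* key forms is hypothetical; the real helper is _stack_qwen36_moe_per_expert_weights in vllm_metal/compat.py and may differ in details (for example, the concatenation axis here assumes per-expert weights are laid out as (out_features, in_features)):

```python
import re
import mlx.core as mx

_PER_EXPERT = re.compile(
    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<idx>\d+)"
    r"\.(?P<proj>gate_proj|up_proj|down_proj)\.weight$"
)

def stack_per_expert_weights(weights: dict) -> dict:
    # 1. Scan: group per-expert keys by (experts prefix, projection).
    found: dict[str, dict[str, dict[int, str]]] = {}
    for key in weights:
        m = _PER_EXPERT.match(key)
        if m:
            found.setdefault(m["prefix"], {}).setdefault(m["proj"], {})[int(m["idx"])] = key
    if not found:
        return weights  # pre-stacked checkpoint: short-circuit, nothing to do

    out = dict(weights)
    for prefix, projs in found.items():
        # 2. Validate: all three projections present, indices contiguous 0..N-1.
        if set(projs) != {"gate_proj", "up_proj", "down_proj"}:
            raise ValueError(f"{prefix}: missing projection family")
        n = len(projs["gate_proj"])
        for proj, idx_to_key in projs.items():
            if set(idx_to_key) != set(range(n)):
                raise ValueError(f"{prefix}.{proj}: non-contiguous expert indices")
        # 3. Walk: stack each projection along a new expert axis 0,
        #    then fuse gate + up along the intermediate-dim axis (axis 1
        #    under the assumed (out_features, in_features) layout).
        stacked = {
            proj: mx.stack([out.pop(projs[proj][i]) for i in range(n)], axis=0)
            for proj in ("gate_proj", "up_proj", "down_proj")
        }
        out[f"{prefix}.gate_up_proj"] = mx.concatenate(
            [stacked["gate_proj"], stacked["up_proj"]], axis=1
        )
        out[f"{prefix}.down_proj"] = stacked["down_proj"]
    return out
```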

The MoE-only nature of the stacking is made explicit by splitting the sanitize patch by model class:

  • mlx_lm.models.qwen3_5.Model → wrapped with FP8 dequant only (_transform_dense)
  • mlx_lm.models.qwen3_5_moe.Model → wrapped with FP8 dequant + per-expert stacking (_transform_moe)

Routing is driven by an explicit transforms_by_module map; future Qwen variants added without a corresponding entry get logged as unpatchable rather than silently inheriting one of the two transforms.
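A hedged sketch of that routing idea follows; the actual map, transform names, and patching mechanics live in vllm_metal/compat.py and may differ. The MoE transform below reuses the stacking sketch above and treats FP8 dequant as a stand-in:

```python
import importlib
import logging

logger = logging.getLogger(__name__)

def _transform_dense(weights):
    # Stand-in for the FP8 dequant step (_dequantize_qwen35_fp8_weights in compat.py).
    return weights

def _transform_moe(weights):
    # Stand-in: FP8 dequant first, then per-expert stacking
    # (stack_per_expert_weights from the sketch above).
    return stack_per_expert_weights(_transform_dense(weights))

# Explicit per-class routing: no fallback, so a new Qwen variant without an
# entry is reported as unpatchable rather than silently inheriting a transform.
transforms_by_module = {
    "mlx_lm.models.qwen3_5": _transform_dense,
    "mlx_lm.models.qwen3_5_moe": _transform_moe,
}

def patch_sanitize(module_name: str) -> None:
    transform = transforms_by_module.get(module_name)
    if transform is None:
        logger.warning("no sanitize transform registered for %s; leaving unpatched", module_name)
        return
    module = importlib.import_module(module_name)
    original = module.Model.sanitize

    def sanitize(self, weights):
        # Run the compat transform before handing off to upstream sanitize.
        return original(self, transform(weights))

    module.Model.sanitize = sanitize
```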

Why a vllm-metal compat shim instead of waiting for upstream

mlx-community publishes pre-stacked redistributions of this checkpoint that already load on existing mlx-lm. This shim lets users load Qwen-org's canonical FP8 artifact directly, without first converting through a 35 GB→70 GB bf16 intermediate that doesn't fit on memory-constrained Macs (≤64 GB).

This complements ml-explore/mlx-lm#1224, which adds the same per-expert stacking logic plus FP8 weight_scale_inv dequant for the qwen3_5 family inline in upstream sanitize. Once mlx-lm#1224 lands and vllm-metal's mlx-lm pin bumps past a release containing it, both this PR's per-expert stacking shim and the existing _dequantize_qwen35_fp8_weights shim in compat.py become removable in a follow-up cleanup.

Files

  • vllm_metal/compat.py — add _stack_qwen36_moe_per_expert_weights helper; split sanitize patches into per-class transforms via transforms_by_module.
  • docs/supported_models.md — update Qwen3.6 row note + link this PR.
  • tests/test_compat.py — four new unit tests using the existing numpy-fake-mlx fixture (no real model weights; they run in milliseconds):
    • test_per_expert_moe_tensors_stack_to_combined — positive: per-expert input produces correctly stacked combined output, content preserved per axis-0 slot.
    • test_pre_stacked_moe_is_noop_for_per_expert_helper — regression: pre-stacked input passes through unchanged (covers mlx-community redistributions and Qwen3.6 bf16 master).
    • test_non_contiguous_per_expert_indices_raise — defensive: malformed {0, 1, 3} checkpoint raises ValueError.
    • test_per_expert_helper_does_not_run_on_dense_qwen35 — architecture invariant: dense path doesn't run the MoE helper.
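As a rough illustration of the positive-path test only: the real test uses the repo's numpy-fake-mlx fixture, which is not reproduced here, so this sketch builds real mlx arrays instead and calls the stack_per_expert_weights sketch from earlier (a hypothetical name, not the actual compat.py helper):

```python
import numpy as np
import mlx.core as mx

def test_per_expert_moe_tensors_stack_to_combined_sketch():
    n_experts, d_ff, d_model = 4, 8, 6
    weights = {}
    for e in range(n_experts):
        for proj in ("gate_proj", "up_proj", "down_proj"):
            shape = (d_model, d_ff) if proj == "down_proj" else (d_ff, d_model)
            key = f"model.language_model.layers.0.mlp.experts.{e}.{proj}.weight"
            # Fill expert e's tensors with the value e so slot order is checkable.
            weights[key] = mx.array(np.full(shape, float(e), dtype=np.float32))

    out = stack_per_expert_weights(weights)

    gate_up = np.array(out["model.language_model.layers.0.mlp.experts.gate_up_proj"])
    down = np.array(out["model.language_model.layers.0.mlp.experts.down_proj"])
    assert gate_up.shape == (n_experts, 2 * d_ff, d_model)
    assert down.shape == (n_experts, d_model, d_ff)
    # Content preserved per axis-0 slot: expert e's values land in slot e.
    for e in range(n_experts):
        assert np.all(gate_up[e] == float(e))
        assert np.all(down[e] == float(e))
```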

Verification (per #289 pass bar)

  • Qwen/Qwen3.6-35B-A3B-FP8, M3 Max / 128 GB: ✅ loads and generates; the prompt "The capital of France is" completes with " Paris, a city renowned for its iconic landmarks such".
  • Qwen/Qwen3.6-35B-A3B (bf16), M3 Max / 128 GB: ✅ loads via the unchanged combined-format branch and produces the same output (both shims dormant, confirming the new branch is properly gated).
  • Hybrid SDPA + GDN linear attention path on Apple Silicon Metal, paged KV cache.
  • pytest tests/test_compat.py: 15 passed, 1 skipped (the skip is the pre-existing VLLM_METAL_RUN_REAL_MLX_FP8_TESTS=1-gated test).
  • Existing Qwen3.5 golden-token smoke (test_qwen35_smoke.py): 5/5 pass, unchanged.
  • bash scripts/lint.sh: clean (shellcheck, ruff check, ruff format --check, mypy).

Rebased on latest main (3323d32) and re-validated against the bumped dep stack: mlx-lm 0.31.3 (from #313), vllm 0.20.0+cpu (from #262), transformers 5.7.0.

@sdayal sdayal force-pushed the qwen3_6-moe-per-expert-stacking branch 5 times, most recently from 903b1f2 to 89207df on April 28, 2026 at 19:52
Collaborator

@LxYuan0420 LxYuan0420 left a comment


Left some comments.

Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs per-expert
(model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight),
which mlx_lm.qwen3_5_moe.sanitize doesn't recognize. Loading on unpatched
mlx-lm fails strict load_weights with 30720 unexpected keys (256 experts
x 40 layers x 3 projections).

The bf16 master Qwen/Qwen3.6-35B-A3B is already pre-stacked
(experts.gate_up_proj / experts.down_proj) and loads via the existing
combined-format branch unchanged; only the FP8 release lands per-expert,
likely because Qwen's FP8 quantization pipeline runs per-expert.

Add _stack_qwen36_moe_per_expert_weights mirroring the (scan -> validate
-> walk) structure of ml-explore/mlx-lm#1224. Split the sanitize patch
into per-class transforms so the MoE-only nature of the stacking is
self-evident:

- mlx_lm.models.qwen3_5.Model      -> _transform_dense (FP8 dequant only)
- mlx_lm.models.qwen3_5_moe.Model  -> _transform_moe   (FP8 dequant + stack)

mlx-community publishes pre-stacked redistributions, but converting
Qwen-org's FP8 release ourselves needs a 35GB->70GB bf16 intermediate
that doesn't fit on Macs <=64 GB. This shim lets users load the
canonical artifact directly. Removable once vllm-metal's mlx-lm pin
bumps past mlx-lm#1224.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights helper,
  split sanitize patches by model class via transforms_by_module map.
- docs/supported_models.md: update Qwen3.6 row note + link this PR.
- tests/test_compat.py: 4 new unit tests covering positive (per-expert
  -> combined), regression (pre-stacked no-op), defensive (gap raises),
  and architecture invariant (dense path doesn't run MoE helper).

Verified end-to-end on Apple Silicon Metal: Qwen/Qwen3.6-35B-A3B-FP8
generates correctly via the new branch, Qwen/Qwen3.6-35B-A3B (bf16) via
the unchanged combined-format branch (both shims dormant - confirms the
new branch is properly gated). Existing Qwen3.5 golden-token smoke
(test_qwen35_smoke.py) unchanged: 5/5 pass.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
@sdayal sdayal force-pushed the qwen3_6-moe-per-expert-stacking branch from 89207df to 1a9622e on April 29, 2026 at 12:07
@sdayal
Contributor Author

sdayal commented Apr 29, 2026

Thanks for the review. Pushed 1a9622e addressing all three:

  • Test surface: removed the opt-in 35B smoke; added four unit tests in tests/test_compat.py covering positive / pre-stacked no-op / gap-raises / dense-path-isolation.
  • Helper shape: rewrote to mirror upstream mlx-lm#1224's (scan → validate → walk) structure. Doc comment now points at the upstream PR with an explicit removability marker.
  • Patch ownership: split into _transform_dense (qwen3_5) and _transform_moe (qwen3_5_moe), driven by an explicit transforms_by_module map.

Pre-submit clean: pytest tests/test_compat.py 15 passed / 1 skipped, bash scripts/lint.sh all four stages green. Rebased on latest main (3323d32) and re-validated against the bumped dep stack: mlx-lm 0.31.3 (from #313), vllm 0.20.0+cpu (from #262), transformers 5.7.0.

End-to-end re-verified on Qwen/Qwen3.6-35B-A3B-FP8 (per-expert path fires) and Qwen/Qwen3.6-35B-A3B bf16 master (combined-format path, both shims dormant).

Scan for gate_proj/up_proj/down_proj index sets together and raise a
named ValueError on missing-family or mismatched-index cases, instead
of leaking a KeyError from dict.pop during the walk step. Adds a
focused test covering both flavors.

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
@sdayal sdayal force-pushed the qwen3_6-moe-per-expert-stacking branch from 6caa176 to 9e721c2 on April 30, 2026 at 06:57
Collaborator

@LxYuan0420 LxYuan0420 left a comment


LGTM; TODO: remove it once mlx-lm#1224 lands in the pinned upstream version

@LxYuan0420 LxYuan0420 merged commit 197215d into vllm-project:main Apr 30, 2026
5 checks passed
@sdayal sdayal deleted the qwen3_6-moe-per-expert-stacking branch April 30, 2026 10:17
