
Commit 3b43ed7

[Qwen3.6] Stack per-expert MoE tensors during mlx_lm sanitize
Qwen-org Qwen3.6 MoE checkpoints (e.g. Qwen/Qwen3.6-35B-A3B-FP8) ship expert
MLPs as one tensor per expert per projection:

    model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight

mlx_lm.qwen3_5_moe.Model.sanitize expects the combined-format layout
(experts.gate_up_proj / experts.down_proj) produced by mlx_lm.convert and
shipped by mlx-community redistributions. Loading a Qwen-org checkpoint fails
strict load_weights with thousands of unexpected keys (30720 for the 35B-A3B
variant: 256 experts x 40 layers x 3 projections).

Extend the existing FP8 sanitize compat shim with a pre-step that detects
per-expert MoE tensors, validates the index range is contiguous from 0, and
stacks them (mx.stack along axis 0) into the combined experts.gate_up_proj +
experts.down_proj form the upstream sanitize already handles. Pre-stacked
checkpoints are unaffected (the helper short-circuits when no per-expert keys
are present).

This is the downstream complement to ml-explore/mlx-lm#1224, which adds the
same stacking logic inline in qwen3_5_moe.Model.sanitize. When that lands and
vllm-metal's mlx-lm pin bumps past it, this shim can be removed.

Files:
- vllm_metal/compat.py: add _stack_qwen36_moe_per_expert_weights helper,
  chained after FP8 dequant in the patched sanitize.
- docs/supported_models.md: update Qwen3.6 row note.
- tests/test_qwen36_smoke.py: opt-in e2e smoke gated on the QWEN36_MOE_FP8_PATH
  env var (pytest skips by default; runs in 25s against a local Qwen3.6 MoE
  FP8 checkpoint).

Verified end-to-end on Qwen/Qwen3.6-35B-A3B-FP8: greedy decode of "The capital
of France is" returns " Paris, a city renowned for its iconic landmarks such"
with the hybrid SDPA + GDN linear attention path on Apple Silicon Metal. The
existing Qwen3.5 golden-token smoke (test_qwen35_smoke.py) is unchanged: 5/5
pass.
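The shape contract of the stacking step can be sketched with synthetic tensors (not part of the commit: numpy stands in for mx, which exposes the same `stack`/`concatenate` signatures, and the tiny dimensions are illustrative only):

```python
import numpy as np

# Synthetic per-expert checkpoint slice: 2 experts in one layer. The real
# Qwen3.6-35B-A3B checkpoint has 256 experts x 40 layers x 3 projections
# = 30720 per-expert tensors.
E, d_model, d_ff = 2, 4, 8
weights = {}
for e in range(E):
    base = f"model.language_model.layers.0.mlp.experts.{e}"
    weights[f"{base}.gate_proj.weight"] = np.zeros((d_ff, d_model))
    weights[f"{base}.up_proj.weight"] = np.zeros((d_ff, d_model))
    weights[f"{base}.down_proj.weight"] = np.zeros((d_model, d_ff))

# Stack over experts (axis 0), then fuse gate+up along the intermediate axis,
# matching the combined layout mlx_lm's sanitize expects.
prefix = "model.language_model.layers.0.mlp.experts"
gate = np.stack([weights[f"{prefix}.{e}.gate_proj.weight"] for e in range(E)])
up = np.stack([weights[f"{prefix}.{e}.up_proj.weight"] for e in range(E)])
down = np.stack([weights[f"{prefix}.{e}.down_proj.weight"] for e in range(E)])
stacked = {
    f"{prefix}.gate_up_proj": np.concatenate([gate, up], axis=-2),  # (E, 2*d_ff, d_model)
    f"{prefix}.down_proj": down,                                    # (E, d_model, d_ff)
}
print(stacked[f"{prefix}.gate_up_proj"].shape)  # (2, 16, 4)
```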
1 parent 992c797 commit 3b43ed7

3 files changed

Lines changed: 152 additions & 2 deletions

File tree

docs/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ Metal. Qwen3 is explicitly covered by the paged prefix-cache e2e test.
 | --- | --- | --- | --- | --- | --- |
 | Qwen3 || GQA (paged) || [#232](https://github.com/vllm-project/vllm-metal/pull/232), [#237](https://github.com/vllm-project/vllm-metal/pull/237), [#283](https://github.com/vllm-project/vllm-metal/pull/283) | Validated by the paged prefix-cache e2e test |
 | Qwen3.5 || Hybrid SDPA + GDN linear || [#210](https://github.com/vllm-project/vllm-metal/pull/210), [#226](https://github.com/vllm-project/vllm-metal/pull/226), [#230](https://github.com/vllm-project/vllm-metal/pull/230), [#235](https://github.com/vllm-project/vllm-metal/pull/235), [#239](https://github.com/vllm-project/vllm-metal/pull/239), [#243](https://github.com/vllm-project/vllm-metal/pull/243), [#259](https://github.com/vllm-project/vllm-metal/pull/259), [#265](https://github.com/vllm-project/vllm-metal/pull/265), [#194](https://github.com/vllm-project/vllm-metal/issues/194) | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
-| Qwen3.6 || Hybrid SDPA + GDN linear (MoE) || | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
+| Qwen3.6 || Hybrid SDPA + GDN linear (MoE) || | Verified on `Qwen/Qwen3.6-35B-A3B-FP8`. Per-expert MoE tensors stacked at sanitize. Upstream keeps automatic prefix caching off for hybrid/Mamba models |
 | Qwen3-Next || Hybrid SDPA + GDN linear || [#240](https://github.com/vllm-project/vllm-metal/pull/240) | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
 | Gemma 4 | 🔵 | GQA + per-layer sliding window + YOCO || [#251](https://github.com/vllm-project/vllm-metal/pull/251), [#260](https://github.com/vllm-project/vllm-metal/pull/260), [#269](https://github.com/vllm-project/vllm-metal/pull/269), [#275](https://github.com/vllm-project/vllm-metal/pull/275), [#277](https://github.com/vllm-project/vllm-metal/pull/277), [#278](https://github.com/vllm-project/vllm-metal/pull/278), [#282](https://github.com/vllm-project/vllm-metal/pull/282), [#276](https://github.com/vllm-project/vllm-metal/issues/276), [#279](https://github.com/vllm-project/vllm-metal/pull/279), [#281](https://github.com/vllm-project/vllm-metal/issues/281), [#283](https://github.com/vllm-project/vllm-metal/pull/283) | Default-on for non-hybrid paged models; overall model support remains experimental |
 | Gemma 3 || GQA (paged) || [#283](https://github.com/vllm-project/vllm-metal/pull/283) | tested on gemma-3-1b-it-qat-4bit; gemma-3-4b-it-4bit verified for text-only generation with VLM image inputs bypassed |

tests/test_qwen36_smoke.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+"""End-to-end smoke for Qwen3.6 MoE FP8 (Qwen-org per-expert layout).
+
+Exercises the per-expert MoE stacking compat path in
+``vllm_metal.compat._stack_qwen36_moe_per_expert_weights`` plus FP8 dequant,
+hybrid SDPA + GDN linear attention, and paged KV cache. Skipped unless a local
+checkpoint is available, since the smallest Qwen3.6 MoE FP8 weight is ~35 GB
+and is not appropriate for CI.
+
+Run with a local checkpoint:
+
+    QWEN36_MOE_FP8_PATH=~/models/Qwen3.6-35B-A3B-FP8 \\
+    VLLM_ENABLE_V1_MULTIPROCESSING=0 \\
+    python -m pytest tests/test_qwen36_smoke.py -v -s
+"""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+import pytest
+
+MODEL_PATH_ENV = "QWEN36_MOE_FP8_PATH"
+MAX_TOKENS = 10
+PROMPT = "The capital of France is"
+
+
+def _resolved_model_path() -> Path | None:
+    raw = os.environ.get(MODEL_PATH_ENV)
+    if not raw:
+        return None
+    path = Path(os.path.expanduser(raw))
+    return path if path.is_dir() else None
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _set_env():
+    with pytest.MonkeyPatch.context() as mp:
+        mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+        if "VLLM_METAL_MEMORY_FRACTION" not in os.environ:
+            mp.setenv("VLLM_METAL_MEMORY_FRACTION", "auto")
+        yield
+
+
+@pytest.mark.slow
+def test_qwen36_moe_fp8_generates():
+    model_path = _resolved_model_path()
+    if model_path is None:
+        pytest.skip(
+            f"Set {MODEL_PATH_ENV} to a Qwen3.6 MoE FP8 checkpoint directory to run."
+        )
+
+    from vllm import LLM, SamplingParams
+
+    llm = LLM(model=str(model_path), max_model_len=512, max_num_seqs=1)
+    sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
+    outputs = llm.generate([PROMPT], sp)
+
+    text = outputs[0].outputs[0].text
+    token_ids = list(outputs[0].outputs[0].token_ids)
+
+    print(f"\n prompt: {PROMPT!r}")
+    print(f" output: {text!r}")
+    print(f" ids: {token_ids}")
+
+    # Loose factual assertion: greedy decode of "The capital of France is" must
+    # surface "Paris" within the first MAX_TOKENS tokens for any reasonable
+    # Qwen3.6-A3B variant. Tighter golden IDs would be brittle across mlx
+    # versions and quant formats.
+    assert len(token_ids) == MAX_TOKENS, f"expected {MAX_TOKENS} tokens, got {token_ids}"
+    assert "Paris" in text, f"expected 'Paris' in greedy output, got {text!r}"

vllm_metal/compat.py

Lines changed: 79 additions & 1 deletion
@@ -131,6 +131,79 @@ def _dequantize_qwen35_fp8_weights(
     return new_weights
 
 
+def _stack_qwen36_moe_per_expert_weights(
+    weights: Mapping[str, Any], mx: Any
+) -> Mapping[str, Any]:
+    """Combine per-expert MoE tensors into the stacked layout mlx_lm expects.
+
+    Qwen-org Qwen3.6 MoE checkpoints (e.g. ``Qwen/Qwen3.6-35B-A3B-FP8``) ship
+    expert MLPs as one tensor per expert per projection:
+    ``...mlp.experts.{E}.{gate,up,down}_proj.weight``. ``mlx_lm.qwen3_5_moe``'s
+    ``sanitize`` expects them already concatenated as
+    ``...mlp.experts.gate_up_proj`` (gate then up along the intermediate axis)
+    and ``...mlp.experts.down_proj``, both stacked along axis 0 over experts.
+
+    No-op when no per-expert keys are present (dense Qwen3.5/3.6 or already-
+    stacked MoE checkpoints).
+    """
+    experts_marker = ".mlp.experts."
+    proj_suffixes = (".gate_proj.weight", ".up_proj.weight", ".down_proj.weight")
+    groups: dict[str, dict[str, dict[int, Any]]] = {}
+    consumed: set[str] = set()
+    for key in weights:
+        marker_pos = key.find(experts_marker)
+        if marker_pos == -1:
+            continue
+        suffix = next((s for s in proj_suffixes if key.endswith(s)), None)
+        if suffix is None:
+            continue
+        index_start = marker_pos + len(experts_marker)
+        index_end = len(key) - len(suffix)
+        index_str = key[index_start:index_end]
+        if not index_str.isdigit():
+            continue
+        prefix = key[: marker_pos + len(".mlp.experts")]
+        proj = suffix[1:-len(".weight")]  # ".gate_proj.weight" -> "gate_proj"
+        groups.setdefault(prefix, {}).setdefault(proj, {})[int(index_str)] = weights[key]
+        consumed.add(key)
+
+    if not groups:
+        return weights
+
+    logger.debug(
+        "Stacking per-expert MoE tensors at %d prefixes (%d tensors consumed)",
+        len(groups),
+        len(consumed),
+    )
+    new_weights = {k: v for k, v in weights.items() if k not in consumed}
+    for prefix, proj_to_experts in groups.items():
+        missing = {"gate_proj", "up_proj", "down_proj"} - proj_to_experts.keys()
+        if missing:
+            raise ValueError(
+                f"Incomplete per-expert MoE tensors at {prefix!r}: "
+                f"missing projections {sorted(missing)}."
+            )
+        expert_indices = sorted(proj_to_experts["gate_proj"].keys())
+        expected = list(range(len(expert_indices)))
+        if expert_indices != expected:
+            raise ValueError(
+                f"Non-contiguous per-expert MoE indices at {prefix!r}: "
+                f"got {expert_indices[:3]}{expert_indices[-3:]}, "
+                f"expected 0..{len(expected) - 1}."
+            )
+        for proj in ("up_proj", "down_proj"):
+            if sorted(proj_to_experts[proj].keys()) != expert_indices:
+                raise ValueError(
+                    f"Per-expert MoE index mismatch at {prefix!r}.{proj}."
+                )
+        gate = mx.stack([proj_to_experts["gate_proj"][i] for i in expert_indices])
+        up = mx.stack([proj_to_experts["up_proj"][i] for i in expert_indices])
+        down = mx.stack([proj_to_experts["down_proj"][i] for i in expert_indices])
+        new_weights[f"{prefix}.gate_up_proj"] = mx.concatenate([gate, up], axis=-2)
+        new_weights[f"{prefix}.down_proj"] = down
+    return new_weights
+
+
 def _patch_mlx_lm_qwen35_fp8_sanitize() -> None:
     """Teach mlx_lm's Qwen3.5 loaders to consume local FP8 ``weight_scale_inv``.
@@ -177,11 +250,16 @@ def _patch_mlx_lm_qwen35_fp8_sanitize() -> None:
         )
         return
 
+    def _transform(_self, weights):
+        weights = _dequantize_qwen35_fp8_weights(weights, mx)
+        weights = _stack_qwen36_moe_per_expert_weights(weights, mx)
+        return weights
+
     def _patch_model_sanitize(model_cls) -> bool:
         return _wrap_model_sanitize(
             model_cls,
             "_vllm_metal_qwen35_fp8_patch",
-            lambda _self, weights: _dequantize_qwen35_fp8_weights(weights, mx),
+            _transform,
         )
 
     patched_modules = []
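The wrap-and-guard pattern that `_patch_model_sanitize` relies on can be illustrated generically (a minimal sketch with hypothetical names; `_wrap_model_sanitize`'s real signature lives in vllm_metal/compat.py):

```python
def wrap_sanitize(model_cls, guard_attr, transform):
    """Run `transform` on the weights before the class's original sanitize.

    `guard_attr` marks the class so repeated patching is a no-op.
    """
    if getattr(model_cls, guard_attr, False):
        return False  # already patched
    original = model_cls.sanitize

    def sanitize(self, weights):
        # Transform first, then hand off to the untouched upstream sanitize.
        return original(self, transform(self, weights))

    model_cls.sanitize = sanitize
    setattr(model_cls, guard_attr, True)
    return True


# Demo with a stand-in model class: the transform lowercases keys, then the
# original sanitize drops rotary-embedding keys.
class Model:
    def sanitize(self, weights):
        return {k: v for k, v in weights.items() if not k.endswith(".rope")}

wrap_sanitize(Model, "_patched", lambda _self, w: {k.lower(): v for k, v in w.items()})
out = Model().sanitize({"A.weight": 1, "B.rope": 2})
print(out)  # {'a.weight': 1}
```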
