Commit 197215d
[Qwen3.6] Stack per-expert MoE tensors during mlx_lm sanitize (#312)
Tracking issue: #289

## Summary

`Qwen/Qwen3.6-35B-A3B-FP8` ships expert MLPs as one tensor per expert per projection:

```
model.language_model.layers.{L}.mlp.experts.{E}.{gate,up,down}_proj.weight
```

The bf16 master `Qwen/Qwen3.6-35B-A3B` is already pre-stacked (`experts.gate_up_proj` / `experts.down_proj`) and loads unchanged via the existing combined-format branch in `mlx_lm.qwen3_5_moe.Model.sanitize` — only the FP8 release lands per-expert, likely because Qwen's FP8 quantization pipeline runs per-expert and the artifact is not re-stacked.

On vllm-metal `main`, loading `Qwen/Qwen3.6-35B-A3B-FP8` fails strict `load_weights` with `Received 30720 parameters not in model` — these are per-expert MoE tensors that vllm-metal's existing FP8 dequant compat doesn't address. (For reference: the same checkpoint on vanilla mlx-lm fails with 61,690 keys; the difference is the 30,970 FP8 `weight_scale_inv` tensors that vllm-metal's `compat.py::_dequantize_qwen35_fp8_weights` already handles separately.)

## What this PR does

Adds `_stack_qwen36_moe_per_expert_weights`, chained after FP8 dequant in the MoE sanitize wrapper. It works in three steps (a toy numpy sketch of these steps follows just after this commit message):

1. **Scan** the weights dict for per-layer experts prefixes and their expert-index sets.
2. **Validate** that each prefix's index set is a contiguous `{0, 1, …, N-1}` (raise `ValueError` otherwise).
3. **Walk** the per-expert tensors in order, `mx.stack` along axis 0, `mx.concatenate` gate+up along the intermediate-dim axis, and emit the combined `experts.gate_up_proj` / `experts.down_proj` form upstream sanitize already handles.

Pre-stacked checkpoints are unaffected: the helper short-circuits when no per-expert keys are present.

The MoE-only nature of the stacking is made explicit by splitting the sanitize patch by model class:

- `mlx_lm.models.qwen3_5.Model` → wrapped with **FP8 dequant only** (`_transform_dense`)
- `mlx_lm.models.qwen3_5_moe.Model` → wrapped with **FP8 dequant + per-expert stacking** (`_transform_moe`)

Routing is driven by an explicit `transforms_by_module` map; future Qwen variants added without a corresponding entry are logged as `unpatchable` rather than silently inheriting one of the two transforms.

## Why a vllm-metal compat shim instead of waiting for upstream

mlx-community publishes pre-stacked redistributions of this checkpoint that already load on existing mlx-lm. This shim lets users load Qwen-org's canonical FP8 artifact directly, without a 35 GB → 70 GB bf16 intermediate conversion step that doesn't fit on memory-constrained Macs (≤64 GB).

This complements ml-explore/mlx-lm#1224, which adds the same per-expert stacking logic plus FP8 `weight_scale_inv` dequant for the qwen3_5 family inline in upstream sanitize. Once mlx-lm#1224 lands and vllm-metal's mlx-lm pin bumps past a release containing it, both this PR's per-expert stacking shim and the existing `_dequantize_qwen35_fp8_weights` shim in `compat.py` become removable in a follow-up cleanup.

## Files

- `vllm_metal/compat.py` — add the `_stack_qwen36_moe_per_expert_weights` helper; split the sanitize patches into per-class transforms via `transforms_by_module`.
- `docs/supported_models.md` — update the Qwen3.6 row note and link this PR.
- `tests/test_compat.py` — five new unit tests using the existing numpy-fake-mlx fixture (no real model weights, runs in milliseconds):
  - `test_per_expert_moe_tensors_stack_to_combined` — positive: per-expert input produces correctly stacked combined output, with content preserved per axis-0 slot.
  - `test_pre_stacked_moe_is_noop_for_per_expert_helper` — regression: pre-stacked input passes through unchanged (covers mlx-community redistributions and the Qwen3.6 bf16 master).
  - `test_non_contiguous_per_expert_indices_raise` — defensive: a malformed `{0, 1, 3}` checkpoint raises `ValueError`.
  - `test_missing_projection_family_raises` — defensive: a checkpoint missing an entire projection family, or with mismatched index sets across families, raises a `ValueError` naming the problem.
  - `test_per_expert_helper_does_not_run_on_dense_qwen35` — architecture invariant: the dense path doesn't run the MoE helper.

## Verification (per #289 pass bar)

| Checkpoint | Hardware | Status | Output |
| --- | --- | --- | --- |
| `Qwen/Qwen3.6-35B-A3B-FP8` | M3 Max / 128 GB | ✅ loads, generates | `"The capital of France is"` → `" Paris, a city renowned for its iconic landmarks such"` |
| `Qwen/Qwen3.6-35B-A3B` (bf16) | M3 Max / 128 GB | ✅ loads via the unchanged combined-format branch (both shims dormant — confirms the new branch is properly gated) | same output |

- Hybrid SDPA + GDN linear attention path on Apple Silicon Metal, paged KV cache.
- `pytest tests/test_compat.py`: **15 passed, 1 skipped** (the skip is the pre-existing `VLLM_METAL_RUN_REAL_MLX_FP8_TESTS=1`-gated test).
- Existing Qwen3.5 golden-token smoke (`test_qwen35_smoke.py`): 5/5 pass, unchanged.
- `bash scripts/lint.sh`: clean (shellcheck, ruff check, ruff format --check, mypy).

Rebased on latest `main` (3323d32) and re-validated against the bumped dep stack: `mlx-lm 0.31.3` (from #313), `vllm 0.20.0+cpu` (from #262), `transformers 5.7.0`.

---------

Signed-off-by: Shivendra Dayal <sdayal@gmail.com>
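To make the axis bookkeeping concrete, here is the toy sketch referenced in the step list above. It uses numpy in place of `mlx.core` (the same substitution the unit tests' fake-mlx fixture makes); the shapes and fill values are illustrative only, not taken from the checkpoint.

```python
import numpy as np

n_experts, d_inter, d_hidden = 2, 6, 4
prefix = "model.language_model.layers.0.mlp.experts"

# Per-expert layout as shipped by Qwen/Qwen3.6-35B-A3B-FP8, with toy values.
weights = {}
for e in range(n_experts):
    weights[f"{prefix}.{e}.gate_proj.weight"] = np.full((d_inter, d_hidden), 3.0 * e + 1)
    weights[f"{prefix}.{e}.up_proj.weight"] = np.full((d_inter, d_hidden), 3.0 * e + 2)
    weights[f"{prefix}.{e}.down_proj.weight"] = np.full((d_hidden, d_inter), 3.0 * e + 3)

# Walk in expert order: stack each projection family along a new axis 0
# (experts), then concatenate gate and up along axis -2, the intermediate
# dimension of the (experts, intermediate, hidden) stack.
gates = np.stack([weights[f"{prefix}.{e}.gate_proj.weight"] for e in range(n_experts)])
ups = np.stack([weights[f"{prefix}.{e}.up_proj.weight"] for e in range(n_experts)])
downs = np.stack([weights[f"{prefix}.{e}.down_proj.weight"] for e in range(n_experts)])
gate_up = np.concatenate([gates, ups], axis=-2)

assert gate_up.shape == (n_experts, 2 * d_inter, d_hidden)  # (2, 12, 4)
assert downs.shape == (n_experts, d_hidden, d_inter)  # (2, 4, 6)
# Expert 0's gate fills the first half of axis -2; its up fills the second.
assert (gate_up[0, :d_inter] == 1.0).all()
assert (gate_up[0, d_inter:] == 2.0).all()
```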
1 parent ac94ebb commit 197215d

3 files changed

Lines changed: 283 additions & 10 deletions


docs/supported_models.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ Metal. Qwen3 is explicitly covered by the paged prefix-cache e2e test.
 | --- | --- | --- | --- | --- | --- |
 | Qwen3 || GQA (paged) || [#232](https://github.com/vllm-project/vllm-metal/pull/232), [#237](https://github.com/vllm-project/vllm-metal/pull/237), [#283](https://github.com/vllm-project/vllm-metal/pull/283) | Validated by the paged prefix-cache e2e test |
 | Qwen3.5 || Hybrid SDPA + GDN linear || [#210](https://github.com/vllm-project/vllm-metal/pull/210), [#226](https://github.com/vllm-project/vllm-metal/pull/226), [#230](https://github.com/vllm-project/vllm-metal/pull/230), [#235](https://github.com/vllm-project/vllm-metal/pull/235), [#239](https://github.com/vllm-project/vllm-metal/pull/239), [#243](https://github.com/vllm-project/vllm-metal/pull/243), [#259](https://github.com/vllm-project/vllm-metal/pull/259), [#265](https://github.com/vllm-project/vllm-metal/pull/265), [#194](https://github.com/vllm-project/vllm-metal/issues/194) | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
-| Qwen3.6 || Hybrid SDPA + GDN linear (MoE) || | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
+| Qwen3.6 || Hybrid SDPA + GDN linear (MoE) || [#312](https://github.com/vllm-project/vllm-metal/pull/312) | Verified on `Qwen/Qwen3.6-35B-A3B-FP8`. Per-expert MoE tensors stacked at sanitize. Upstream keeps automatic prefix caching off for hybrid/Mamba models |
 | Qwen3-Next || Hybrid SDPA + GDN linear || [#240](https://github.com/vllm-project/vllm-metal/pull/240) | Upstream keeps automatic prefix caching off for hybrid/Mamba models |
 | Gemma 4 | 🔵 | GQA + per-layer sliding window + YOCO || [#251](https://github.com/vllm-project/vllm-metal/pull/251), [#260](https://github.com/vllm-project/vllm-metal/pull/260), [#269](https://github.com/vllm-project/vllm-metal/pull/269), [#275](https://github.com/vllm-project/vllm-metal/pull/275), [#277](https://github.com/vllm-project/vllm-metal/pull/277), [#278](https://github.com/vllm-project/vllm-metal/pull/278), [#282](https://github.com/vllm-project/vllm-metal/pull/282), [#276](https://github.com/vllm-project/vllm-metal/issues/276), [#279](https://github.com/vllm-project/vllm-metal/pull/279), [#281](https://github.com/vllm-project/vllm-metal/issues/281), [#283](https://github.com/vllm-project/vllm-metal/pull/283) | Default-on for non-hybrid paged models; overall model support remains experimental |
 | Gemma 3 || GQA (paged) || [#283](https://github.com/vllm-project/vllm-metal/pull/283) | tested on gemma-3-1b-it-qat-4bit; gemma-3-4b-it-4bit verified for text-only generation with VLM image inputs bypassed |
```

tests/test_compat.py

Lines changed: 157 additions & 0 deletions
```diff
@@ -21,6 +21,8 @@ def _install_fake_qwen35_modules(monkeypatch, *, include_moe: bool):
     mlx_core.bfloat16 = np.float32
     mlx_core.from_fp8 = lambda weight, dtype=None: np.asarray(weight, dtype=np.float32)
     mlx_core.pad = lambda weight, pad_width: np.pad(weight, pad_width)
+    mlx_core.stack = lambda arrays, axis=0: np.stack(arrays, axis=axis)
+    mlx_core.concatenate = lambda arrays, axis=0: np.concatenate(arrays, axis=axis)
     mlx_pkg.core = mlx_core
     monkeypatch.setitem(sys.modules, "mlx", mlx_pkg)
     monkeypatch.setitem(sys.modules, "mlx.core", mlx_core)
@@ -175,6 +177,161 @@ def test_patches_higher_rank_weights_for_moe(self, monkeypatch) -> None:
         assert f"{gate_up_proj_prefix}.activation_scale" not in sanitized
         assert sanitized[f"{gate_up_proj_prefix}.weight"].shape == (2, 256, 128)
 
+    def test_per_expert_moe_tensors_stack_to_combined(self, monkeypatch) -> None:
+        # Qwen/Qwen3.6-35B-A3B-FP8 ships expert MLPs per-expert. The MoE
+        # sanitize wrapper must stack them along axis 0 and concatenate
+        # gate+up along the intermediate-dim axis, producing the combined
+        # form upstream sanitize already handles.
+        _, moe_module = _install_fake_qwen35_modules(monkeypatch, include_moe=True)
+        prefix = "model.language_model.layers.0.mlp.experts"
+
+        compat._patch_mlx_lm_qwen35_fp8_sanitize()
+
+        per_expert = {
+            f"{prefix}.0.gate_proj.weight": np.full((6, 4), 1.0),
+            f"{prefix}.0.up_proj.weight": np.full((6, 4), 2.0),
+            f"{prefix}.0.down_proj.weight": np.full((4, 6), 3.0),
+            f"{prefix}.1.gate_proj.weight": np.full((6, 4), 4.0),
+            f"{prefix}.1.up_proj.weight": np.full((6, 4), 5.0),
+            f"{prefix}.1.down_proj.weight": np.full((4, 6), 6.0),
+        }
+        sanitized = moe_module.Model().sanitize(per_expert)
+
+        gate_up_key = f"{prefix}.gate_up_proj"
+        down_key = f"{prefix}.down_proj"
+        assert gate_up_key in sanitized
+        assert down_key in sanitized
+        # gate_up: (num_experts, 2*intermediate, hidden); down: (num_experts, hidden, intermediate)
+        assert sanitized[gate_up_key].shape == (2, 12, 4)
+        assert sanitized[down_key].shape == (2, 4, 6)
+        # Per-expert keys must not leak through after stacking.
+        assert all(".experts.0." not in k for k in sanitized)
+        assert all(".experts.1." not in k for k in sanitized)
+        # Stacking preserves per-expert content along axis 0; gate occupies
+        # the first half of axis -2, up occupies the second half.
+        np.testing.assert_array_equal(
+            sanitized[gate_up_key][0, :6, :], np.full((6, 4), 1.0)
+        )
+        np.testing.assert_array_equal(
+            sanitized[gate_up_key][0, 6:, :], np.full((6, 4), 2.0)
+        )
+        np.testing.assert_array_equal(
+            sanitized[down_key][1, :, :], np.full((4, 6), 6.0)
+        )
+
+    def test_pre_stacked_moe_is_noop_for_per_expert_helper(self, monkeypatch) -> None:
+        # Pre-stacked checkpoints (mlx-community redistributions, Qwen3.6 bf16
+        # master) ship `experts.gate_up_proj` / `experts.down_proj` already
+        # combined. The per-expert helper must short-circuit and pass them
+        # through unchanged, leaving the combined-format branch in upstream
+        # sanitize free to do its split.
+        _, moe_module = _install_fake_qwen35_modules(monkeypatch, include_moe=True)
+        prefix = "model.language_model.layers.0.mlp.experts"
+
+        compat._patch_mlx_lm_qwen35_fp8_sanitize()
+
+        gate_up = np.arange(2 * 12 * 4, dtype=np.float32).reshape(2, 12, 4)
+        down = np.arange(2 * 4 * 6, dtype=np.float32).reshape(2, 4, 6)
+        weights = {
+            f"{prefix}.gate_up_proj": gate_up,
+            f"{prefix}.down_proj": down,
+        }
+        sanitized = moe_module.Model().sanitize(weights)
+
+        # Helper is a no-op: combined keys present unchanged, no per-expert
+        # keys appear.
+        np.testing.assert_array_equal(sanitized[f"{prefix}.gate_up_proj"], gate_up)
+        np.testing.assert_array_equal(sanitized[f"{prefix}.down_proj"], down)
+        assert not any(f"{prefix}.0." in k for k in sanitized)
+
+    def test_non_contiguous_per_expert_indices_raise(self, monkeypatch) -> None:
+        # Defensive: a malformed checkpoint shipping experts {0, 1, 3} (skipping
+        # 2) would silently drop expert 3 if the stacker walked indices in
+        # order. Helper must raise loudly so the user diagnoses the missing
+        # tensor instead of getting subtly wrong output.
+        _, moe_module = _install_fake_qwen35_modules(monkeypatch, include_moe=True)
+        prefix = "model.language_model.layers.0.mlp.experts"
+
+        compat._patch_mlx_lm_qwen35_fp8_sanitize()
+
+        gapped = {
+            f"{prefix}.0.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.0.down_proj.weight": np.zeros((4, 6)),
+            f"{prefix}.1.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.1.down_proj.weight": np.zeros((4, 6)),
+            f"{prefix}.3.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.3.up_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.3.down_proj.weight": np.zeros((4, 6)),
+        }
+
+        with pytest.raises(ValueError, match="non-contiguous"):
+            moe_module.Model().sanitize(gapped)
+
+    def test_missing_projection_family_raises(self, monkeypatch) -> None:
+        # Defensive: a malformed checkpoint missing one entire projection
+        # family (e.g., no down_proj at all) must surface as a clear
+        # ValueError naming the missing family, rather than a raw KeyError
+        # leaking from the walk step. The same path also covers the case
+        # where only some experts have a given projection (mismatched index
+        # sets across families).
+        _, moe_module = _install_fake_qwen35_modules(monkeypatch, include_moe=True)
+        prefix = "model.language_model.layers.0.mlp.experts"
+
+        compat._patch_mlx_lm_qwen35_fp8_sanitize()
+
+        # 1) Entire down_proj family absent.
+        no_down = {
+            f"{prefix}.0.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.1.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": np.zeros((6, 4)),
+        }
+        with pytest.raises(ValueError, match="missing projection families"):
+            moe_module.Model().sanitize(no_down)
+
+        # 2) down_proj missing for one expert (mismatched index sets).
+        partial_down = {
+            f"{prefix}.0.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.0.up_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.0.down_proj.weight": np.zeros((4, 6)),
+            f"{prefix}.1.gate_proj.weight": np.zeros((6, 4)),
+            f"{prefix}.1.up_proj.weight": np.zeros((6, 4)),
+            # missing f"{prefix}.1.down_proj.weight"
+        }
+        with pytest.raises(ValueError, match="mismatched down_proj"):
+            moe_module.Model().sanitize(partial_down)
+
+    def test_per_expert_helper_does_not_run_on_dense_qwen35(self, monkeypatch) -> None:
+        # The dense qwen3_5 patch wraps sanitize with FP8 dequant only — the
+        # per-expert stacking helper must NOT run on dense Qwen3.5/3.6
+        # checkpoints (no expert tensors exist in dense models, so even an
+        # accidental call would be a no-op, but the patch architecture
+        # makes the MoE-only nature explicit).
+        dense_module, _ = _install_fake_qwen35_modules(monkeypatch, include_moe=True)
+
+        compat._patch_mlx_lm_qwen35_fp8_sanitize()
+
+        # Dense weights with FP8 quant; no expert tensors anywhere.
+        sanitized = dense_module.Model().sanitize(
+            {
+                "model.language_model.layers.0.self_attn.q_proj.weight": np.ones(
+                    (128, 128)
+                ),
+                "model.language_model.layers.0.self_attn.q_proj.weight_scale_inv": np.ones(
+                    (1, 1)
+                ),
+            }
+        )
+        assert (
+            "model.language_model.layers.0.self_attn.q_proj.weight_scale_inv"
+            not in sanitized
+        )
+        assert sanitized[
+            "model.language_model.layers.0.self_attn.q_proj.weight"
+        ].shape == (128, 128)
+
 
 def _install_fake_gemma4_text_module(
     monkeypatch,
```

vllm_metal/compat.py

Lines changed: 125 additions & 9 deletions
```diff
@@ -131,6 +131,101 @@ def _dequantize_qwen35_fp8_weights(
     return new_weights
 
 
+def _stack_qwen36_moe_per_expert_weights(
+    weights: Mapping[str, Any], mx: Any
+) -> Mapping[str, Any]:
+    """Combine per-expert MoE tensors into the stacked layout mlx_lm expects.
+
+    ``Qwen/Qwen3.6-35B-A3B-FP8`` ships expert MLPs as one tensor per expert
+    per projection: ``...mlp.experts.{E}.{gate,up,down}_proj.weight``. The
+    bf16 master ``Qwen/Qwen3.6-35B-A3B`` is already pre-stacked and falls
+    through to the existing combined-format branch in
+    ``mlx_lm.qwen3_5_moe.sanitize`` unchanged. ``mlx_lm.qwen3_5_moe``'s
+    ``sanitize`` expects experts concatenated as
+    ``...mlp.experts.gate_up_proj`` (gate then up along the intermediate axis)
+    and ``...mlp.experts.down_proj``, both stacked along axis 0 over experts.
+
+    Mirrors the (scan -> validate -> walk) structure of upstream
+    ml-explore/mlx-lm#1224. Removable once vllm-metal's mlx-lm pin bumps
+    past that merge.
+
+    No-op when no per-expert keys are present (dense Qwen3.5/3.6 or already-
+    stacked MoE checkpoints).
+    """
+    experts_marker = ".mlp.experts."
+    proj_suffixes = (".gate_proj.weight", ".up_proj.weight", ".down_proj.weight")
+    # Scan: discover per-layer experts prefixes and per-projection index sets
+    # for all three projection families, so a checkpoint missing one family
+    # (or with a mismatched index set across families) fails validation
+    # cleanly instead of leaking a KeyError during the walk.
+    layer_proj_indices: dict[str, dict[str, set[int]]] = {}
+    for key in weights:
+        marker_pos = key.find(experts_marker)
+        if marker_pos == -1:
+            continue
+        suffix = next((s for s in proj_suffixes if key.endswith(s)), None)
+        if suffix is None:
+            continue
+        index_start = marker_pos + len(experts_marker)
+        index_end = len(key) - len(suffix)
+        tail = key[index_start:index_end]
+        if not tail.isdigit():
+            continue
+        prefix = key[: marker_pos + len(".mlp.experts")]
+        proj = suffix[1 : -len(".weight")]  # ".gate_proj.weight" -> "gate_proj"
+        layer_proj_indices.setdefault(prefix, {}).setdefault(proj, set()).add(int(tail))
+
+    if not layer_proj_indices:
+        return weights
+
+    logger.debug(
+        "Stacking per-expert MoE tensors at %d prefixes",
+        len(layer_proj_indices),
+    )
+    required_projs = ("gate_proj", "up_proj", "down_proj")
+    new_weights = dict(weights)
+    for prefix, proj_to_indices in layer_proj_indices.items():
+        # Validate: every prefix must have all three projection families, and
+        # all three must share the same contiguous {0..N-1} index set.
+        missing_projs = [p for p in required_projs if p not in proj_to_indices]
+        if missing_projs:
+            raise ValueError(
+                f"Per-expert MoE weights at {prefix!r} are missing "
+                f"projection families: {missing_projs}."
+            )
+        gate_indices = proj_to_indices["gate_proj"]
+        expected = set(range(len(gate_indices)))
+        if gate_indices != expected:
+            missing = sorted(expected - gate_indices)
+            extra = sorted(gate_indices - expected)
+            raise ValueError(
+                f"Per-expert MoE weights at {prefix!r} have "
+                f"non-contiguous gate_proj indices: missing={missing}, "
+                f"unexpected={extra}."
+            )
+        for proj in ("up_proj", "down_proj"):
+            if proj_to_indices[proj] != gate_indices:
+                missing = sorted(gate_indices - proj_to_indices[proj])
+                extra = sorted(proj_to_indices[proj] - gate_indices)
+                raise ValueError(
+                    f"Per-expert MoE weights at {prefix!r} have "
+                    f"mismatched {proj} indices vs gate_proj: "
+                    f"missing={missing}, unexpected={extra}."
+                )
+        # Walk: pop per-expert tensors in order, stack, and emit the combined
+        # form upstream sanitize already handles.
+        gates, ups, downs = [], [], []
+        for e in range(len(gate_indices)):
+            gates.append(new_weights.pop(f"{prefix}.{e}.gate_proj.weight"))
+            ups.append(new_weights.pop(f"{prefix}.{e}.up_proj.weight"))
+            downs.append(new_weights.pop(f"{prefix}.{e}.down_proj.weight"))
+        new_weights[f"{prefix}.gate_up_proj"] = mx.concatenate(
+            [mx.stack(gates), mx.stack(ups)], axis=-2
+        )
+        new_weights[f"{prefix}.down_proj"] = mx.stack(downs)
+    return new_weights
+
+
 def _patch_mlx_lm_qwen35_fp8_sanitize() -> None:
     """Teach mlx_lm's Qwen3.5 loaders to consume local FP8 ``weight_scale_inv``.
 
@@ -177,22 +272,43 @@ def _patch_mlx_lm_qwen35_fp8_sanitize() -> None:
         )
         return
 
-    def _patch_model_sanitize(model_cls) -> bool:
-        return _wrap_model_sanitize(
-            model_cls,
-            "_vllm_metal_qwen35_fp8_patch",
-            lambda _self, weights: _dequantize_qwen35_fp8_weights(weights, mx),
-        )
+    # qwen3_5 (dense) checkpoints only need FP8 dequant — they have no expert
+    # tensors to stack. Keep the dense patch narrow.
+    def _transform_dense(_self, weights):
+        return _dequantize_qwen35_fp8_weights(weights, mx)
+
+    # qwen3_5_moe (Qwen-org Qwen3.6-MoE FP8) needs FP8 dequant followed by
+    # per-expert stacking. The stacking step is the temporary downstream
+    # complement to ml-explore/mlx-lm#1224 and short-circuits when no
+    # per-expert keys are present.
+    def _transform_moe(_self, weights):
+        weights = _dequantize_qwen35_fp8_weights(weights, mx)
+        weights = _stack_qwen36_moe_per_expert_weights(weights, mx)
+        return weights
+
+    transforms_by_module: dict[str, Any] = {
+        "mlx_lm.models.qwen3_5": _transform_dense,
+        "mlx_lm.models.qwen3_5_moe": _transform_moe,
+    }
 
     patched_modules = []
     unpatchable_modules = []
     for module in model_modules:
+        short_name = module.__name__.rsplit(".", maxsplit=1)[-1]
         model_cls = getattr(module, "Model", None)
         if model_cls is None:
-            unpatchable_modules.append(module.__name__.rsplit(".", maxsplit=1)[-1])
+            unpatchable_modules.append(short_name)
             continue
-        if _patch_model_sanitize(model_cls):
-            patched_modules.append(module.__name__.rsplit(".", maxsplit=1)[-1])
+        transform = transforms_by_module.get(module.__name__)
+        if transform is None:
+            unpatchable_modules.append(short_name)
+            continue
+        if _wrap_model_sanitize(
+            model_cls,
+            "_vllm_metal_qwen35_fp8_patch",
+            transform,
+        ):
+            patched_modules.append(short_name)
     if patched_modules:
         logger.debug(
             "Patched mlx_lm %s FP8 sanitize compatibility",
```
