
Prompt cache silently broken for Qwen3-Next (hybrid Gated-DeltaNet + SWA Gated-Attention) — recomputes every turn #1162

@bmaxim77


Summary

The PromptCache in mlx-lm does not correctly persist state for Qwen3-Next-80B-A3B (hybrid architecture: 36 Gated-Attention layers with SWA + 12 Gated-DeltaNet recurrent layers + MoE routing). The cache hit rate appears to be ~0 on multi-turn workloads: every repeat of the exact same prompt triggers a full prefill. No error is raised — the cache silently recomputes.

By contrast, the same model served via llama.cpp's llama-server (post PR #19408, merged Feb 2026) achieves correct cache reuse with 40×+ speedup on warm turns.

Environment

  • Mac Studio M4 Ultra, 128 GB unified memory, macOS 25.3
  • mlx-lm served via LM Studio (MLX backend, build from 2026-04)
  • Model: qwen3-next-80b-a3b-instruct-mlx@8bit (80.8 GiB)
  • Test harness: direct streaming HTTP POST to /v1/chat/completions, measuring TTFT

Expected behavior

On a repeat of the exact same prompt within the loaded model's lifetime, TTFT should drop from cold (~70 s for 60K tokens) to hit (~1–2 s).

Actual behavior

TTFT stays at cold levels for every turn. There is no speedup from cache reuse.

Numbers — cold TTFT (no cache needed — first turn)

Input size    TTFT
12K           12.4 s
30K           28.7 s
60K           71.6 s

Numbers — "hot" TTFT (same prompt, repeated) — should be 1–2 s

Input size    TTFT observed
12K           ~12 s — no change from cold
30K           ~28 s — no change from cold
60K           ~70 s — no change from cold

Comparison — same model, llama.cpp backend (cache works)

Input size    Cold TTFT    Hot TTFT    Speedup
60K           54 s         1.0 s       54×
120K          250 s        ~1.2 s      ~208×

Config: --swa-full --ctx-checkpoints 128 --checkpoint-every-n-tokens 4096 --slot-prompt-similarity 0.3

Reproduction

  1. Load qwen3-next-80b-a3b-instruct-mlx@8bit in LM Studio (or mlx-lm directly)
  2. Send a ~10K token prompt via /v1/chat/completions, record TTFT
  3. Send the exact same prompt again, record TTFT
  4. Observe: both TTFTs are similar (both cold) — cache is not hit
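
For concreteness, here is a minimal probe of the kind used for the numbers above: it times the first streamed token of the same prompt sent twice to an OpenAI-compatible /v1/chat/completions endpoint. The base URL, model identifier, and prompt size are placeholders for a local LM Studio setup, not anything mlx-lm-specific.

```python
# Send the same prompt twice and time the first streamed content chunk (TTFT).
# BASE_URL, MODEL, and PROMPT are placeholders; adjust for your local server.
import json
import time

import requests

BASE_URL = "http://localhost:1234/v1"           # LM Studio's default local port (assumption)
MODEL = "qwen3-next-80b-a3b-instruct-mlx@8bit"  # placeholder model identifier
PROMPT = "lorem ipsum " * 6000                  # rough filler prompt on the order of 10K tokens


def ttft_seconds(prompt: str) -> float:
    """Seconds from request start until the first streamed content chunk."""
    start = time.perf_counter()
    with requests.post(
        f"{BASE_URL}/chat/completions",
        json={
            "model": MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "stream": True,
            "max_tokens": 16,
        },
        stream=True,
        timeout=600,
    ) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
                continue
            chunk = json.loads(line[len(b"data: "):])
            choices = chunk.get("choices") or []
            if choices and choices[0].get("delta", {}).get("content"):
                return time.perf_counter() - start
    raise RuntimeError("stream ended before any content was produced")


if __name__ == "__main__":
    cold = ttft_seconds(PROMPT)  # first turn: full prefill expected
    hot = ttft_seconds(PROMPT)   # identical prompt: should hit the prompt cache
    print(f"cold TTFT: {cold:.1f} s, hot TTFT: {hot:.1f} s")
```

With a working prompt cache the second number should be one to two orders of magnitude smaller; here both stay at cold levels.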

Root cause hypothesis

PromptCache appears to be designed for pure transformer layers (a list of K/V tensors per layer). Qwen3-Next has 12 Gated-DeltaNet layers that hold recurrent state, not K/V. When the cache "restores" a prior prompt, those 12 layers presumably reset to zero state, so the effective model output is wrong → the server falls back to recomputing the full prompt.

Related: the SWA (sliding-window attention) layers may also need special handling around window boundaries that the current cache class doesn't implement.
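
To make the hypothesis concrete, here is a schematic sketch of the mismatch; the classes below are illustrative only, not mlx-lm's actual cache types:

```python
# Schematic only: why a K/V-shaped cache cannot represent Gated-DeltaNet state.
# An attention cache grows along the sequence axis and can be cut back to any
# earlier position; a recurrent layer's state is a fixed-size tensor that only
# reflects "everything processed so far" and cannot be rewound without a snapshot.
from dataclasses import dataclass, field

import numpy as np


@dataclass
class AttentionKV:
    # (positions, heads, head_dim); the positions axis grows as tokens are prefilled.
    keys: np.ndarray = field(default_factory=lambda: np.zeros((0, 8, 64)))
    values: np.ndarray = field(default_factory=lambda: np.zeros((0, 8, 64)))

    def trim_to(self, pos: int) -> None:
        # Valid for attention: dropping trailing positions recovers the state at `pos`.
        self.keys, self.values = self.keys[:pos], self.values[:pos]


@dataclass
class DeltaNetState:
    # Fixed-size recurrent state; there is no per-position axis to slice.
    state: np.ndarray = field(default_factory=lambda: np.zeros((8, 64, 64)))

    def trim_to(self, pos: int) -> None:
        # The only thing a K/V-style cache can do is reset it, which changes
        # the model's output for every token that follows.
        self.state = np.zeros_like(self.state)


# A cache that persists only K/V restores the 36 attention layers correctly but
# leaves the 12 DeltaNet layers zeroed, so the "restored" prefix is not
# equivalent to having actually processed the prompt.
```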

Reference implementation

llama.cpp PR #19408 (merged Feb 2026) added hybrid-architecture checkpoint save/restore. The key additions:

  • Per-layer polymorphic cache entry type (K/V for attention, recurrent state for Mamba/DeltaNet)
  • Position-indexed checkpoints (--checkpoint-every-n-tokens)
  • Prefix-match restore that replays only new tokens

A Python port of the same logic into mlx-lm should be feasible in ~1–2 engineer-weeks.
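
For illustration, a rough sketch of what that port could look like; the names below (HybridPromptCache, LayerSnapshot, snapshot_all_layers) are hypothetical and do not correspond to existing mlx-lm or llama.cpp APIs:

```python
# Illustrative sketch: position-indexed checkpoints plus prefix-match restore
# for a hybrid (attention + recurrent) model. Not existing mlx-lm API.
from dataclasses import dataclass


@dataclass
class LayerSnapshot:
    kind: str        # "attention" (K/V up to pos) or "recurrent" (DeltaNet state at pos)
    payload: object  # per-layer tensors, copied at checkpoint time


@dataclass
class Checkpoint:
    pos: int                      # number of prompt tokens reflected in the snapshots
    layers: list[LayerSnapshot]   # one entry per layer, attention and recurrent alike


class HybridPromptCache:
    def __init__(self, checkpoint_every: int = 4096):
        self.checkpoint_every = checkpoint_every
        self.tokens: list[int] = []       # prompt tokens processed so far
        self.checkpoints: list[Checkpoint] = []

    def maybe_checkpoint(self, model_state) -> None:
        """Called by the prefill loop after tokens are appended to self.tokens;
        snapshots every layer (recurrent state included) every N prompt tokens."""
        if self.tokens and len(self.tokens) % self.checkpoint_every == 0:
            # `snapshot_all_layers` is a hypothetical hook on the running model.
            self.checkpoints.append(
                Checkpoint(pos=len(self.tokens), layers=model_state.snapshot_all_layers())
            )

    def restore_for(self, new_tokens: list[int]):
        """Return (checkpoint, tokens_to_replay) for the longest reusable prefix."""
        common = 0
        for a, b in zip(self.tokens, new_tokens):
            if a != b:
                break
            common += 1
        # Use the latest checkpoint that lies entirely inside the matched prefix;
        # only the tokens after it need to be prefilled again.
        usable = [c for c in self.checkpoints if c.pos <= common]
        if not usable:
            return None, new_tokens
        best = max(usable, key=lambda c: c.pos)
        return best, new_tokens[best.pos:]
```

The essential point is that recurrent layers are restored from a snapshot taken at a known position rather than trimmed or reset, so replaying only the new tokens should reproduce the same state as a full prefill.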

Impact

Qwen3-Next is a 2025 flagship MoE release and a primary target for local agentic workloads (3.9B active of 80B total — fits comfortably on 128 GB consumer machines). Multi-turn / conversational / tool-calling workloads lose the primary performance lever (prefix-cache reuse).

Until this is fixed, users running Qwen3-Next on Apple silicon are better served by the llama.cpp GGUF path, even though MLX cold prefill is ~15 % faster. The cache deficit dominates after turn 2.

Related

  • Existing issue: mlx-lm#980 (speculative decoding breaks prompt cache — similar cache-path fragility)
  • Upstream reference: ggml-org/llama.cpp#19408
