[transformers-to-mlx skill] Add bailing_hybrid (Ling-2.6-flash) model#1227
ivanfioravanti wants to merge 6 commits into ml-explore:main from
Conversation
Adds support for inclusionAI/Ling-2.6-flash (BailingMoeV2_5ForCausalLM, model_type=bailing_hybrid): a 104B / 7.4B-active hybrid that mixes MLA with Lightning-style linear attention (1:7 ratio, layers 7/15/23/31 are MLA) and a sigmoid noaux_tc MoE (256 experts, 1 shared, group-limited top-8).

- New self-contained model file mlx_lm/models/bailing_hybrid.py
- Reuses MLA absorbed-form (embed_q/unembed_out split) from deepseek_v3
- Reuses linear-attention recurrence + GroupRMSNorm + group-limited MoE patterns from bailing_moe_linear; slope formula matches transformers exactly (no clamp on layer_idx - 1)
- sanitize() drops the trailing MTP layer, splits MLA kv_b_proj, and stacks MoE experts into SwitchGLU
- make_cache() returns KVCache for MLA layers and ArraysCache(size=1) for linear-attention layers
- Adds bailing_hybrid config to tests/test_models.py::test_all_models

Generated with the transformers-to-mlx skill. Tested via 4-bit conversion (208GB -> 55GB @ 4.501 bpw) on an M3 Ultra: the dtype check passes, and 600-token code and 2930-token long-form generations stay coherent end-to-end.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
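For readers unfamiliar with mixed caches in mlx_lm, here is a minimal sketch of the `make_cache()` shape described above. The `full_attn_layers` attribute is a hypothetical name for wherever the MLA layer indices (7/15/23/31) are stored; the actual field name in the PR may differ.

```python
from mlx_lm.models.cache import ArraysCache, KVCache  # assumed import path, as in other mlx_lm models


def make_cache(self):
    # MLA layers keep a standard growing KV cache; linear-attention layers only
    # carry a single recurrent state array, hence ArraysCache(size=1).
    return [
        KVCache() if i in self.args.full_attn_layers else ArraysCache(size=1)
        for i in range(self.args.num_hidden_layers)
    ]
```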
```python
def recurrent_gla(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    scale: float,
    h: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
    # Gated linear-attention recurrence, stepped one token at a time:
    #   h_t = exp(g) * h_{t-1} + k_t^T v_t,  o_t = q_t h_t
    L = q.shape[2]
    exp_g = mx.exp(g)[:, None, None].astype(q.dtype)  # per-head decay, broadcast over the state
    q = q * scale
    outputs = []
    for t in range(L):
        q_t = q[:, :, t : t + 1]
        k_t = k[:, :, t : t + 1]
        v_t = v[:, :, t : t + 1]
        h_up = k_t.transpose(0, 1, 3, 2) @ v_t  # outer-product update, shape (B, H, Dk, Dv)
        h = h_up if h is None else h * exp_g + h_up
        outputs.append(q_t @ h)
    return mx.concatenate(outputs, axis=2), h
```
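For reference, a minimal smoke test of the recurrence above. It assumes `g` is a per-head log-decay of shape `(H,)`, which is what the broadcasting in `exp_g` implies; the sizes are illustrative, not taken from the model.

```python
import mlx.core as mx

B, H, L, D = 1, 2, 4, 8                  # illustrative sizes
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
g = mx.array([-0.5, -1.0])               # per-head log-decay, so exp(g) is in (0, 1)
out, state = recurrent_gla(q, k, v, g, scale=D**-0.5)
print(out.shape, state.shape)            # (1, 2, 4, 8) (1, 2, 8, 8)
```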
We could optimize this with a metal kernel
To make it even faster! Go for it boss!
I'm trying to build this and test performance, keep you posted 💪
Much faster! Thanks @kernelpool. I see the same pattern in bailing_moe_linear.py, so I'm going to apply the optimization there too and test pre/post for both.
Performance was worse in bailing_moe_linear.py, so I'm keeping the existing implementation there.
Comparison for bailing_hybrid with 2k context:
- pre Metal kernel: 693/79 t/s, 61.1 GB; batch 8 tg: 1014/218 tps, 77.5 GB
- post Metal kernel: 1357/80 t/s, 59.8 GB; batch 8 tg: 1621/226 tps, 66.2 GB 🔥
Replace the Python `for t in range(L)` recurrence in `recurrent_gla` with a Metal kernel that fuses the entire gated linear-attention loop into a single GPU dispatch (modeled after `gated_delta` and `rwkv7`). Fall back to a compiled-step ops loop when Metal is unavailable or when head dims don't fit the kernel's constraints (`D % 32 == 0`, `Dv % 4 == 0`, `Dk == Dv`).

Microbench (B=1, H=32, T=512, D=128): ~13.5x faster than the loop (12.2 ms -> 0.9 ms). Kernel vs. ops outputs agree to ~1.4e-6.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
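A rough sketch of the dispatch-and-fallback logic that commit describes. The `_recurrent_gla_kernel` and `_recurrent_gla_ops` names are placeholders for the fused Metal path and the compiled ops loop, not the PR's actual symbols.

```python
import mlx.core as mx


def recurrent_gla(q, k, v, g, scale, h=None):
    Dk, Dv = q.shape[-1], v.shape[-1]
    use_metal = (
        mx.metal.is_available()   # Metal backend present
        and Dk == Dv              # kernel constraints from the commit message
        and Dk % 32 == 0
        and Dv % 4 == 0
    )
    if use_metal:
        return _recurrent_gla_kernel(q, k, v, g, scale, h)  # single fused GPU dispatch
    return _recurrent_gla_ops(q, k, v, g, scale, h)          # compiled-step ops fallback
```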
There are still some issues with longer prompts. I'm investigating.
Made-with: Cursor
fix: snapshot bailing hybrid cache offsets
Fixed! Testing with OpenCode and pi mono raised an issue: the cache offset logic was wrong.
The recurrent_gla Metal kernel does not implement vjp, so backprop through it fails with "Primitive::vjp Not implemented for CustomKernel" when running DWQ (or any mx.value_and_grad path). Add a use_kernel flag to recurrent_gla and pass `not self.training` from LinearAttention so the differentiable ops fallback runs during training while inference still uses the fast kernel. Matches the pattern used in qwen3_next, kimi_linear, and qwen3_5.
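A minimal sketch of how that flag might be threaded through. The projection/merge helpers and the cache indexing here are assumptions for illustration, not the PR's actual code.

```python
import mlx.nn as nn


class LinearAttention(nn.Module):
    def __call__(self, x, cache=None):
        q, k, v, g = self.project(x)               # hypothetical projection helper
        state = cache[0] if cache is not None else None
        # Training must use the differentiable ops path (the Metal kernel has no
        # vjp); inference keeps the fast fused kernel.
        out, state = recurrent_gla(
            q, k, v, g, self.scale, state, use_kernel=not self.training
        )
        if cache is not None:
            cache[0] = state                       # assumed ArraysCache-style indexing
        return self.merge(out)                     # hypothetical output projection
```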
```python
slopes = mx.array(slopes, dtype=mx.float32)
denom = max(1, self.num_hidden_layers - 1)
layer_factor = 1 - (self.layer_idx - 1) / denom + 1e-5
```
```diff
- layer_factor = 1 - (self.layer_idx - 1) / denom + 1e-5
+ layer_factor = 1 - self.layer_idx / denom + 1e-5
```
This change matches vllm/sglang (the HF reference looks wrong) and improves perplexity in my testing.
| Formula | Perplexity (calibration_v5, 50 × 1024 tokens) |
|---|---|
| `1 - layer_idx / (N-1) + 1e-5` | 5.959 ± 0.013 |
| `1 - (layer_idx - 1) / (N-1) + 1e-5` | 6.021 ± 0.013 |
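To make the difference concrete, here is what the two formulas produce at the endpoint layers for a hypothetical N = 32 layer model:

```python
N = 32
denom = max(1, N - 1)
for layer_idx in (0, N - 1):
    hf_style  = 1 - (layer_idx - 1) / denom + 1e-5   # HF reference: exceeds 1 at layer 0
    suggested = 1 - layer_idx / denom + 1e-5          # vllm/sglang: spans ~1 down to ~0
    print(layer_idx, round(hf_style, 4), round(suggested, 4))
# 0  1.0323 1.0
# 31 0.0323 0.0
```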
Confirmed! 5.54 vs 5.6 with 4-bit group size 32 in my tests.
You rock @kernelpool 🚀
Updated the layer factor calculation as suggested by @kernelpool to match vllm/sglang (the HF reference looks wrong).
Starting some evals! This architecture looks amazing from a performance point of view!
Ready for you super @angeloskath 🚀