
[transformers-to-mlx skill] Add bailing_hybrid (Ling-2.6-flash) model #1227

Open
ivanfioravanti wants to merge 6 commits into ml-explore:main from ivanfioravanti:add-ling-2.6-flash

Conversation

@ivanfioravanti
Contributor

Adds support for inclusionAI/Ling-2.6-flash (BailingMoeV2_5ForCausalLM, model_type=bailing_hybrid): a 104B / 7.4B-active hybrid that mixes MLA with Lightning-style linear attention (1:7 ratio, layers 7/15/23/31 are MLA) and a sigmoid noaux_tc MoE (256 experts, 1 shared, group-limited top-8).

  • New self-contained model file mlx_lm/models/bailing_hybrid.py
  • Reuses MLA absorbed-form (embed_q/unembed_out split) from deepseek_v3
  • Reuses linear-attention recurrence + GroupRMSNorm + group-limited MoE patterns from bailing_moe_linear; slope formula matches transformers exactly (no clamp on layer_idx-1)
  • sanitize() drops the trailing MTP layer, splits MLA kv_b_proj, and stacks MoE experts into SwitchGLU
  • make_cache() returns KVCache for MLA layers and ArraysCache(size=1) for linear-attention layers (see the sketch after this list)
  • Adds bailing_hybrid config to tests/test_models.py::test_all_models
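
A minimal sketch of that make_cache() split, assuming each decoder layer exposes a flag marking the MLA layers (the flag name is an assumption; the real logic lives in mlx_lm/models/bailing_hybrid.py):

from mlx_lm.models.cache import ArraysCache, KVCache

def make_cache(model):
    # KVCache for the four full-attention (MLA) layers; ArraysCache(size=1)
    # holds the recurrent state of every linear-attention layer.
    return [
        KVCache() if layer.is_mla else ArraysCache(size=1)  # is_mla is hypothetical
        for layer in model.layers
    ]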

Generated with the transformers-to-mlx skill. Tested via 4-bit conversion (208GB -> 55GB @ 4.501 bpw) on an M3 Ultra: dtype check passes, 600-token code and 2930-token long-form generations stay coherent end-to-end.
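
For reference, a conversion like the one described above can be run through the standard mlx_lm convert API (a sketch; the exact group-size / mixed-precision settings behind the 4.501 bpw figure aren't spelled out here):

from mlx_lm import convert

# Quantize the HF checkpoint to 4-bit MLX weights.
convert(
    "inclusionAI/Ling-2.6-flash",
    mlx_path="ling-2.6-flash-4bit",
    quantize=True,
    q_bits=4,
)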

Comment thread on mlx_lm/models/bailing_hybrid.py (Outdated)
Comment on lines +64 to +83
def recurrent_gla(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    scale: float,
    h: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
    # Sequential gated linear-attention recurrence over the time axis;
    # h carries the running (B, H, Dk, Dv) state across calls.
    L = q.shape[2]
    exp_g = mx.exp(g)[:, None, None].astype(q.dtype)
    q = q * scale
    outputs = []
    for t in range(L):
        q_t = q[:, :, t : t + 1]
        k_t = k[:, :, t : t + 1]
        v_t = v[:, :, t : t + 1]
        h_up = k_t.transpose(0, 1, 3, 2) @ v_t
        h = h_up if h is None else h * exp_g + h_up
        outputs.append(q_t @ h)
    return mx.concatenate(outputs, axis=2), h
Contributor

We could optimize this with a metal kernel
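
For context, custom kernels in MLX go through mx.fast.metal_kernel, which compiles a Metal source string into a callable op. A minimal example (the element-wise exp from the MLX docs, not the GLA kernel this PR later adds) looks roughly like:

import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

def exp_elementwise(a: mx.array):
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return outputs[0]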

Contributor Author

To make it even faster! Go for it boss!

Contributor Author

I'm trying to build this and test performance, keep you posted 💪

Contributor Author

Much faster! Thanks @kernelpool. I see the same pattern in bailing_moe_linear.py, so I'll apply the optimization there too and test pre/post for both.

Contributor Author

Performance was worse in bailing_moe_linear.py, so I'm keeping the existing implementation there.
Comparison for bailing_hybrid with a 2k context:

  • pre Metal kernel: 693/79 t/s, 61.1 GB memory; batch 8 tg: 1014/218 t/s, 77.5 GB
  • post Metal kernel: 1357/80 t/s, 59.8 GB memory; batch 8 tg: 1621/226 t/s, 66.2 GB 🔥

Contributor

Awesome!

Replace the Python `for t in range(L)` recurrence in `recurrent_gla`
with a Metal kernel that fuses the entire gated linear-attention loop
into a single GPU dispatch (modeled after `gated_delta` and `rwkv7`).
Fall back to a compiled-step ops loop when Metal is unavailable or
when head dims don't fit the kernel's constraints
(`D % 32 == 0`, `Dv % 4 == 0`, `Dk == Dv`).

Microbench (B=1, H=32, T=512, D=128): ~13.5x faster than the loop
(12.2 ms -> 0.9 ms). Kernel vs. ops outputs agree to ~1.4e-6.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
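
A minimal sketch of the dispatch guard those constraints imply (the helper name is illustrative, not the actual code):

import mlx.core as mx

def _can_use_metal_kernel(q, v):
    # Constraints quoted from the commit message: D % 32 == 0, Dv % 4 == 0, Dk == Dv.
    D, Dv = q.shape[-1], v.shape[-1]
    return mx.metal.is_available() and D % 32 == 0 and Dv % 4 == 0 and D == Dv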
@ivanfioravanti
Contributor Author

There are still some issues with longer prompts. I'm investigating.

fix: snapshot bailing hybrid cache offsets
@ivanfioravanti
Contributor Author

Fixed! Testing with OpenCode and pi mono raised an issue. Cache offset logic was wrong.
It seems good to me now.

The recurrent_gla Metal kernel does not implement vjp, so backprop
through it fails with "Primitive::vjp Not implemented for CustomKernel"
when running DWQ (or any mx.value_and_grad path). Add a use_kernel flag
to recurrent_gla and pass `not self.training` from LinearAttention so the
differentiable ops fallback runs during training while inference still
uses the fast kernel. Matches the pattern used in qwen3_next, kimi_linear,
and qwen3_5.
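
A minimal sketch of that flag, with hypothetical helper names standing in for the kernel and ops paths in mlx_lm/models/bailing_hybrid.py:

import mlx.core as mx

def recurrent_gla(q, k, v, g, scale, h=None, use_kernel=True):
    # The Metal kernel has no vjp, so only take it on the inference path.
    if use_kernel and mx.metal.is_available():
        return _recurrent_gla_kernel(q, k, v, g, scale, h)
    return _recurrent_gla_ops(q, k, v, g, scale, h)

# In LinearAttention.__call__, the call site passes use_kernel=not self.training
# so mx.value_and_grad paths (e.g. DWQ) hit the differentiable ops fallback.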
Comment thread on mlx_lm/models/bailing_hybrid.py (Outdated)

slopes = mx.array(slopes, dtype=mx.float32)
denom = max(1, self.num_hidden_layers - 1)
layer_factor = 1 - (self.layer_idx - 1) / denom + 1e-5
@kernelpool (Apr 29, 2026)
Contributor

Suggested change:
- layer_factor = 1 - (self.layer_idx - 1) / denom + 1e-5
+ layer_factor = 1 - self.layer_idx / denom + 1e-5

This change matches vllm/sglang (the HF reference looks wrong) and improves perplexity in my testing.

Formula                               Perplexity (calibration_v5, 50 × 1024 tokens)
1 - layer_idx / (N-1) + 1e-5          5.959 ± 0.013
1 - (layer_idx - 1) / (N-1) + 1e-5    6.021 ± 0.013
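
For reference, the two variants side by side (illustrative only; num_hidden_layers plays the role of N and denom = max(1, N - 1), as in the snippet above):

def layer_factor(layer_idx, num_hidden_layers, hf_style=False):
    # hf_style=True reproduces the HF reference; False matches vllm/sglang,
    # which is the variant adopted in this PR.
    denom = max(1, num_hidden_layers - 1)
    idx = layer_idx - 1 if hf_style else layer_idx
    return 1 - idx / denom + 1e-5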

Contributor Author


Confirmed! 5.54 vs 5.6 at 4-bit, group size 32, in my tests.

You rock @kernelpool 🚀

Updated the layer factor calculation as suggested by @kernelpool to match vllm/sglang (the HF reference looks wrong)
@ivanfioravanti
Contributor Author

Starting some evals! This architecture looks amazing from a performance point of view!

@ivanfioravanti
Contributor Author

Ready for you super @angeloskath 🚀
