
[transformers-to-mlx skill] Add Talkie (TalkieForCausalLM) model #1231

Open
warshanks wants to merge 3 commits into ml-explore:main from warshanks:add-talkie

Conversation

@warshanks

Adds support for lewtun/talkie-1930-13b-it-hf, a custom 13B decoder-only transformer published with auto_map modeling code (no native transformers/models/talkie directory). Knowledge cutoff ≈1930.

Implementation, tests, and the report below were produced with the transformers-to-mlx skill. Generation, dtype, numerical, layer-by-layer, long-sequence, and quantization tests were all run via the skill's bundled scripts.

Architecture novelties

40 layers / 40 heads / head_dim 128 / hidden 5120 / intermediate 13696 / vocab 65540 / max_pos 2048 / rope_theta 1e6 / bf16.

  • Custom RoPE (rotation by −θ). The reference _apply_rotary_emb computes
    y1 =  x1*cos + x2*sin
    y2 = -x1*sin + x2*cos
    
    i.e. rotation matrix [[cos, sin], [-sin, cos]], the inverse of the standard HF/Llama convention. mx.fast.rope uses the standard convention, so a TalkieRoPE module is required (see the sketch after this list).
  • Weightless RMSNorm everywhere — embedding output, pre-attention, pre-MLP, post-RoPE Q/K norm, final pre-lm_head. All in fp32 internally, then cast back. Verified empirically that torch's F.rms_norm(x, (D,)) with default eps=None behaves as eps≈0, not torch.finfo(dtype).eps.
  • Per-head Q gain (head_g[n_head]) applied after RoPE + QK-norm.
  • Per-layer scalar gains. attn_gain / mlp_gain initialized to (2L)^-0.5 scale the residual contributions; embed_skip initialized to 0.0 scales a per-block skip from the first-norm embedding into every layer (x = x + embed_skip * e_x).
  • lm_head with scalar gain. Stored in the checkpoint as a raw (vocab, hidden) parameter plus a scalar lm_head_gain.w_g. The MLX sanitize() folds the gain into lm_head.weight so quantization sees a regular nn.Linear.
  • QK-norm + standard scale. Q/K are RMSNormed before SDPA but the default 1/sqrt(head_dim) softmax scale is kept (rather than the more common QK-norm scale=1).
  • Custom tokenizer backend (tokenizer_class = "TokenizersBackend"). Loads via AutoTokenizer + trust_remote_code=True.
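
For orientation, here is a minimal MLX sketch of how the rotation, the weightless norm, and the per-layer gains interact. This is not the talkie.py code: the function names are illustrative, whether the rotary halves are split or interleaved follows the reference _apply_rotary_emb, and the exact placement of the embed-skip inside a block follows the reference model.

import mlx.core as mx

def weightless_rms_norm(x):
    # No learned weight; reduce in fp32 and cast back. eps is effectively ~0
    # (a later commit uses mx.finfo(mx.float32).eps to match PyTorch).
    x32 = x.astype(mx.float32)
    y = x32 * mx.rsqrt(mx.mean(x32 * x32, axis=-1, keepdims=True)
                       + mx.finfo(mx.float32).eps)
    return y.astype(x.dtype)

def talkie_rotate(x, cos, sin):
    # Rotation by -theta, i.e. [[cos, sin], [-sin, cos]]: the inverse of the
    # standard HF/Llama convention baked into mx.fast.rope, hence the custom op.
    x1, x2 = mx.split(x, 2, axis=-1)
    return mx.concatenate([x1 * cos + x2 * sin, -x1 * sin + x2 * cos], axis=-1)

def block_residual(x, e_x, attn_fn, mlp_fn, attn_gain, mlp_gain, embed_skip):
    # Per-layer scalar gains (init (2L)^-0.5) scale the residual contributions;
    # embed_skip (init 0.0) re-injects the first-norm embedding e_x per block.
    x = x + embed_skip * e_x
    x = x + attn_gain * attn_fn(weightless_rms_norm(x))
    x = x + mlp_gain * mlp_fn(weightless_rms_norm(x))
    return x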

Implementation

mlx_lm/models/talkie.py (single file, ~225 LOC). Mirrors the reference attribute names verbatim (attn.attn_query/attn_key/attn_value/attn_resid, mlp.mlp_gate/mlp_linear/mlp_resid, head_gain.head_g, attn_gain.a_g, mlp_gain.a_g, embed_skip.a_g) so checkpoint keys map without a remap. The only sanitize step is folding lm_head_gain into lm_head.weight.
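
Concretely, that sanitize step is roughly the following (key names are the ones listed above; the method shape follows the usual mlx-lm sanitize() pattern):

def sanitize(self, weights):
    # Fold the scalar lm_head gain into the weight so quantization sees a
    # plain nn.Linear; no other key remapping is needed.
    gain = weights.pop("lm_head_gain.w_g", None)
    if gain is not None:
        weights["lm_head.weight"] = weights["lm_head.weight"] * gain
    return weights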

No changes to shared mlx-lm infrastructure were required.

Tests

Generation

mlx_lm.generate --model lewtun/talkie-1930-13b-it-hf \
  --prompt '<|user|>What are some cutting-edge new technologies and inventions changing the world today?<|end|><|assistant|>' \
  --max-tokens 200 --temp 0

New technologies and inventions which are changing the world today
include the cinema, wireless telegraphy, and telephony, air transport,
and road transport. By means of the cinema, an entire play can be shown
simultaneously in a thousand different places; by wireless, speech can
be transmitted to great distances; by air transport, a passenger can be
carried from London to Paris in a few hours; and by road transport, he
can travel at high speed from Land's End to John o' Groats.

Cinema, wireless telegraphy, John o'Groats — period-appropriate
content confirms the model is loaded and computing correctly.

Long-sequence (410-token greedy story, EOS-terminated)

mlx_lm.generate --model lewtun/talkie-1930-13b-it-hf \
  --prompt '<|user|>Write a long story about an inventor in the year 1925.<|end|><|assistant|>' \
  --max-tokens 1024 --temp 0

Output is a coherent 410-token story arc (invention → exhibition → factory → wholesale → fortune → parliament → death). No repetition or RoPE-degradation patterns.

Output dtype

python scripts/resolve_dtype.py /path/to/model
# dtype: bfloat16
# source: config.dtype

python scripts/check_dtype.py /path/to/model bfloat16
# {"output_dtype": "mlx.core.bfloat16", "expected_dtype": "mlx.core.bfloat16", "match": true}

Numerical comparison vs transformers (bf16, CPU, 94-token paragraph)

Logits diff:    max=2.0000   mean=0.0785   median=0.0625
Last-pos top-10 (HF):  [374, 650, 1155, 31616, 4017, 1407, 499, 5617, 2718, 461]
Last-pos top-10 (MLX): [374, 650, 1155, 31616, 4017, 1407, 499, 5617, 2718, 461]
Top-10 overlap:   10/10
Top-1 agreement:  98.9%   (across all 94 positions)

Comparison source (adapted from the skill's compare_predictions.py to run transformers in bf16 for memory reasons):

compare_talkie.py
import gc
import sys

import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = sys.argv[1]
PROMPT = (
    "The Eiffel Tower is one of the most iconic landmarks in the world. "
    "Built between 1887 and 1889, it stands at 330 meters tall and was the tallest "
    "man-made structure for 41 years. Designed by Gustave Eiffel for the 1889 World's "
    "Fair, it has become a symbol of Paris and France. The tower attracts millions of "
    "visitors each year and offers panoramic views of the city from its three observation decks."
)

tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()
ids = tok.encode(PROMPT)

# Reference logits from transformers (bf16 weights, compared in fp32).
with torch.no_grad():
    hf_logits = hf_model(torch.tensor([ids])).logits.float().numpy()
del hf_model
gc.collect()  # free the HF model before loading the MLX weights

mlx_model, _ = load(MODEL)
mlx_logits = mlx_model(mx.array([ids]))
mx.eval(mlx_logits)
mlx_np = np.array(mlx_logits.astype(mx.float32))

diff = np.abs(hf_logits - mlx_np)
print(f"max={diff.max():.4f} mean={diff.mean():.4f} median={np.median(diff):.4f}")

# Top-10 token ids at the last position and their overlap.
hf_top10 = np.argsort(hf_logits[0, -1])[::-1][:10]
mlx_top10 = np.argsort(mlx_np[0, -1])[::-1][:10]
print(f"top-10 (HF):  {hf_top10.tolist()}")
print(f"top-10 (MLX): {mlx_top10.tolist()}")
print(f"top-10 overlap: {len(set(hf_top10) & set(mlx_top10))}/10")

top1 = (hf_logits.argmax(-1) == mlx_np.argmax(-1)).mean()
print(f"top-1 agreement: {top1*100:.1f}%")

Layer-by-layer (47-token prompt, abs/rel diffs vs HF bf16)

layer       abs max   abs mean
0           0.125     0.0015
1           0.125     0.0009
10          0.344     0.0012
20          0.297     0.0051
30          5.000     0.0454
39          64.000    0.296
post-norm   3.219     0.027
logits      2.000     0.100

Mean abs diff grows steadily with depth (~200× from layer 0 to 39), consistent with bf16 noise accumulating on a residual stream whose magnitude grows over layers (no per-layer normalization on the residual itself). The 64.0 max at layer 39 is a single-position outlier; the final RMSNorm collapses it to 3.2 max, and the logits row sits at 2.0 max / 0.1 mean — within typical transformers/MLX bf16 disagreement.

Quantization

Bits        Group size   bpw    Quality
4           64           4.5    degrades — repetition loop
4           32           5.0    poor — list-style repetition
4 (mixed)   64           5.18   short OK; long-seq fragile
6           64           6.5    good — coherent, era-appropriate
8           64           8.5    good — coherent, era-appropriate
4 (DWQ)     64           4.5    clean short + long

Bare q4 is unusually weak on Talkie. The combination of weightless norms, per-layer scalar gains, and the small 2048-token training horizon amplifies quant noise. The mixed-q4 recipe (lm_head=q8 + embed=bf16 + sensitivity-flagged blocks 14/37/38=q8 + rest=q4) recovers short-form coherence; DWQ-calibrated q4 is the strongest 4-bit option (val loss 0.037 after 512 iterations at default LR=1e-6).
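
As a reference point, a mixed recipe like this can be expressed as a quant_predicate callable passed to mlx_lm's Python convert API. The hook signature and the module-path patterns below are assumptions to check against the installed mlx-lm version and talkie.py's actual module naming; the layer indices are the sensitivity-flagged blocks above.

# Hypothetical sketch of the mixed-q4 recipe; not the exact recipe used in the report.
from mlx_lm import convert

SENSITIVE_BLOCKS = {14, 37, 38}  # sensitivity-flagged layers

def talkie_quant_predicate(path, module, config):
    if "embed" in path:                       # keep embeddings unquantized (bf16)
        return False
    if path == "lm_head":                     # lm_head at 8 bits
        return {"group_size": 64, "bits": 8}
    if any(f"layers.{i}." in path for i in SENSITIVE_BLOCKS):
        return {"group_size": 64, "bits": 8}  # sensitive blocks at 8 bits
    return {"group_size": 64, "bits": 4}      # everything else at 4 bits

convert(
    "lewtun/talkie-1930-13b-it-hf",
    mlx_path="talkie-mlx-mixed-q4",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=talkie_quant_predicate,
)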

mlx_lm.awq raises NotImplementedError: AWQ support for talkie models NYI. AWQ scales must be absorbed into the previous module's weight; talkie's weightless RMSNorms have no weight to absorb into. Adding AWQ would require either a custom AWQConfig matching talkie's structure or converting the weightless norms into learned norms initialized to ones.
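
The constraint can be illustrated with a small numpy sketch (conceptual only, not mlx_lm's AWQ code): per-channel input scales are exact only when they can be folded into a preceding norm weight, which talkie lacks.

import numpy as np

D = 8
W = np.random.randn(16, D)            # projection weight
g = np.ones(D)                        # norm weight (absent in talkie)
s = np.abs(np.random.randn(D)) + 0.1  # AWQ-style per-channel scales
x = np.random.randn(D)

y_ref = W @ (g * x)                   # original computation
y_awq = (W / s) @ ((g * s) * x)       # scales divided out of W, absorbed into g
print(np.allclose(y_ref, y_awq))      # True: absorption is exact only because g exists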

DWQ command

mlx_lm.dwq --model lewtun/talkie-1930-13b-it-hf \
  --mlx-path talkie-mlx-dwq --bits 4 --group-size 64 \
  --num-samples 512 --max-seq-length 512 --batch-size 1 \
  --learning-rate 1e-6 --grad-checkpoint

Note: an initial run with --learning-rate 5e-5 (50× the default) diverged — final val loss 2.07 vs initial 0.25. Default LR is mandatory for talkie; the per-layer scalar gains amplify gradient updates in unusual ways.

Skill feedback (lessons captured back into the skill)

  • F.rms_norm(x, (D,)) with default eps=None behaves as eps≈0, not torch.finfo(dtype).eps. Verified empirically (see the probe after this list).
  • Custom modeling repos (auto_map / trust_remote_code) require fetching modeling_*.py and configuration_*.py directly — they aren't in transformers/models/.
  • Folding scalar weight-gains (lm_head_gain) into lm_head.weight at sanitize time is cleaner than tracking the scalar at runtime, and works correctly through quantization.
  • RoPE sign convention is invisible at first glance and not parameterized in mx.fast.rope; check _apply_rotary_emb source for sign of sin before assuming the standard convention.
  • Built-in mixed_*_* quant recipes hardcode down_proj/v_proj patterns and won't match models with custom naming. Use a custom quant_predicate callable instead.
  • AWQ assumes weight-bearing norms upstream of every quantized projection; weightless-norm models need either a custom AWQConfig plus restructure or a different calibration method.
  • DWQ default learning rate (1e-6) is non-negotiable for models with per-layer scalar gains — even 50× higher diverges.
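
A minimal probe of the first point (plain PyTorch, assuming a version where F.rms_norm exists; the expected outputs are the report's empirical finding, not a claim about the documented default):

import torch
import torch.nn.functional as F

x = torch.randn(4, 5120, dtype=torch.bfloat16)
D = (x.shape[-1],)

default = F.rms_norm(x, D)                                         # eps left at None
eps_zero = F.rms_norm(x, D, eps=0.0)
eps_finfo = F.rms_norm(x, D, eps=torch.finfo(torch.bfloat16).eps)

print(torch.equal(default, eps_zero))    # True per the report: default acts as eps ~ 0
print(torch.equal(default, eps_finfo))   # False per the report: not finfo(dtype).eps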

Adds support for `lewtun/talkie-1930-13b-it-hf`, a 13B decoder-only
transformer published with custom modeling code (no native transformers
module). 40 layers, 40 heads, head_dim=128, hidden=5120, vocab=65540,
max_pos=2048, rope_theta=1e6, bf16.

Notable architecture quirks handled in the new model:
- Custom RoPE convention: rotation by -theta (sign-flipped sin), so
  mx.fast.rope cannot be used; a TalkieRoPE class implements the
  reference _apply_rotary_emb formula.
- Weightless RMSNorm everywhere (embedding output, pre-attention,
  pre-MLP, post-RoPE Q/K norm, final). Reduction in fp32.
- Per-head Q gain applied after RoPE + QK norm.
- Per-layer scalar gains: attn_gain / mlp_gain (init (2L)^-0.5) on
  the residual contributions; embed_skip (init 0.0) scales an
  embedding-skip residual into every block.
- lm_head stored as raw (vocab, hidden) parameter plus a scalar
  lm_head_gain; sanitize() folds the gain into lm_head.weight so
  the quantizer sees a regular nn.Linear.

Implementation, tests, and full conversion report produced with the
transformers-to-mlx skill.

warshanks and others added 2 commits April 30, 2026 15:01
Server's batched generation passes cache.offset as mx.array, but
mx.arange() requires int/float args. Convert offset to int before
calling arange.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace per-attention TalkieRoPE with precomputed cos/sin table,
  sliced by offset — avoids mx.arange type issues in server mode
- Convert cache.offset to int for array slicing compatibility
- Use mx.finfo(mx.float32).eps in RMS norm to match PyTorch
- Add explicit dtype casts in HeadGain and ActGain for mixed precision
- Dynamically extend RoPE table if sequence exceeds max_position_embeddings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@warshanks

I accidentally created this PR before noticing #1220. Not sure if I was affected by a GH bug or what. Keeping it open as this approach skips the extra dependency. Feel free to close; this was more of a learning exercise for me!

