[transformers-to-mlx skill] Add Talkie (TalkieForCausalLM) model #1231
Open

warshanks wants to merge 3 commits into ml-explore:main
Conversation
Adds support for `lewtun/talkie-1930-13b-it-hf`, a 13B decoder-only transformer published with custom modeling code (no native transformers module). 40 layers, 40 heads, head_dim=128, hidden=5120, vocab=65540, max_pos=2048, rope_theta=1e6, bf16. Notable architecture quirks handled in the new model:

- Custom RoPE convention: rotation by -theta (sign-flipped sin), so `mx.fast.rope` cannot be used; a `TalkieRoPE` class implements the reference `_apply_rotary_emb` formula (a sketch follows this list).
- Weightless RMSNorm everywhere (embedding output, pre-attention, pre-MLP, post-RoPE Q/K norm, final). Reduction in fp32.
- Per-head Q gain applied after RoPE + QK norm.
- Per-layer scalar gains: `attn_gain` / `mlp_gain` (init `(2L)^-0.5`) on the residual contributions; `embed_skip` (init 0.0) scales an embedding-skip residual into every block.
- `lm_head` stored as a raw `(vocab, hidden)` parameter plus a scalar `lm_head_gain`; `sanitize()` folds the gain into `lm_head.weight` so the quantizer sees a regular `nn.Linear`.

Implementation, tests, and full conversion report produced with the transformers-to-mlx skill.
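A minimal sketch of that sign-flipped rotation, with illustrative names and an assumed interleaved even/odd pairing; it is not the PR's exact module:

```python
import mlx.core as mx
import mlx.nn as nn

class TalkieRoPESketch(nn.Module):
    """Rotation by -theta, i.e. [[cos, sin], [-sin, cos]] instead of the usual
    [[cos, -sin], [sin, cos]], which is why mx.fast.rope cannot be reused.
    Interleaved even/odd pairing is an assumption of this sketch."""

    def __init__(self, head_dim: int, base: float = 1e6):
        super().__init__()
        exponents = mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim
        self.inv_freq = mx.power(base, -exponents)

    def __call__(self, x: mx.array, offset: int = 0) -> mx.array:
        # x: (batch, n_heads, seq_len, head_dim); offset is the KV-cache position
        t = mx.arange(offset, offset + x.shape[2]).astype(mx.float32)
        freqs = t[:, None] * self.inv_freq[None, :]   # (seq_len, head_dim // 2)
        cos, sin = mx.cos(freqs), mx.sin(freqs)
        x1, x2 = x[..., ::2], x[..., 1::2]
        # sign-flipped sin relative to the standard Llama/HF convention
        rotated = mx.stack([x1 * cos + x2 * sin, x2 * cos - x1 * sin], axis=-1)
        return rotated.reshape(x.shape).astype(x.dtype)
```

Per the later commit, the production version precomputes the cos/sin table once and slices it by an integer offset rather than recomputing per call.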
Server's batched generation passes cache.offset as mx.array, but mx.arange() requires int/float args. Convert offset to int before calling arange.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace per-attention TalkieRoPE with a precomputed cos/sin table, sliced by offset — avoids mx.arange type issues in server mode
- Convert cache.offset to int for array slicing compatibility
- Use mx.finfo(mx.float32).eps in RMS norm to match PyTorch
- Add explicit dtype casts in HeadGain and ActGain for mixed precision
- Dynamically extend the RoPE table if the sequence exceeds max_position_embeddings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
I accidentally created this PR before noticing #1220. Not sure if I was affected by a GH bug or what. Keeping it open as this approach skips the extra dependency. Feel free to close, this was more of a learning exercise for me!
Adds support for `lewtun/talkie-1930-13b-it-hf`, a custom 13B decoder-only transformer published with `auto_map` modeling code (no native `transformers/models/talkie` directory). Knowledge cutoff ≈1930.

Architecture novelties
40 layers / 40 heads / head_dim 128 / hidden 5120 / intermediate 13696 / vocab 65540 / max_pos 2048 / rope_theta 1e6 / bf16.
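Spelled out as a config sketch (field names follow the ModelArgs convention used elsewhere in mlx-lm and are assumptions, not the PR's exact definitions):

```python
from dataclasses import dataclass

# Config sketch implied by the numbers above; field names are assumptions,
# mirroring the usual mlx-lm ModelArgs style rather than the PR's exact code.
@dataclass
class ModelArgs:
    model_type: str = "talkie"
    hidden_size: int = 5120
    num_hidden_layers: int = 40
    num_attention_heads: int = 40
    head_dim: int = 128
    intermediate_size: int = 13696
    vocab_size: int = 65540
    max_position_embeddings: int = 2048
    rope_theta: float = 1e6
```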
- Custom RoPE convention: `_apply_rotary_emb` computes `[[cos, sin], [-sin, cos]]`, the inverse of the standard HF/Llama convention. `mx.fast.rope` uses the standard convention, so a `TalkieRoPE` module is required.
- Weightless RMSNorm everywhere: embedding output, pre-attention, pre-MLP, post-RoPE Q/K norm, and final pre-`lm_head`. All in fp32 internally, then cast back. Verified empirically that torch's `F.rms_norm(x, (D,))` with default `eps=None` behaves as `eps≈0`, not `torch.finfo(dtype).eps`. (A minimal sketch follows this list.)
- Per-head Q gain (`head_g[n_head]`) applied after RoPE + QK-norm.
- Per-layer scalar gains: `attn_gain`/`mlp_gain` initialized to `(2L)^-0.5` scale the residual contributions; `embed_skip` initialized to 0.0 scales a per-block skip from the first-norm embedding into every layer (`x = x + embed_skip * e_x`).
- `lm_head` with scalar gain: stored in the checkpoint as a raw `(vocab, hidden)` parameter plus a scalar `lm_head_gain.w_g`. The MLX `sanitize()` folds the gain into `lm_head.weight` so quantization sees a regular `nn.Linear`.
- The standard `1/sqrt(head_dim)` softmax scale is kept (rather than the more common QK-norm scale=1).
- Custom tokenizer (`tokenizer_class = "TokenizersBackend"`). Loads via `AutoTokenizer` + `trust_remote_code=True`.
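A minimal sketch of that weightless norm, assuming fp32 reduction and the near-zero eps described above (the later commit settles on `mx.finfo(mx.float32).eps`):

```python
import mlx.core as mx

def weightless_rms_norm(x: mx.array, eps: float = 1.19e-7) -> mx.array:
    """No learned scale; reduce in float32 and cast back to the input dtype.
    eps defaults to ~float32 machine epsilon, effectively 0 next to bf16 values."""
    x32 = x.astype(mx.float32)
    x32 = x32 * mx.rsqrt(mx.mean(mx.square(x32), axis=-1, keepdims=True) + eps)
    return x32.astype(x.dtype)
```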
Implementation

`mlx_lm/models/talkie.py` (single file, ~225 LOC). Mirrors the reference attribute names verbatim (`attn.attn_query`/`attn_key`/`attn_value`/`attn_resid`, `mlp.mlp_gate`/`mlp_linear`/`mlp_resid`, `head_gain.head_g`, `attn_gain.a_g`, `mlp_gain.a_g`, `embed_skip.a_g`) so checkpoint keys map without a remap. The only `sanitize` step is folding `lm_head_gain` into `lm_head.weight`, as sketched below.
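A sketch of that fold, with assumed checkpoint key names (the real `sanitize()` may spell them differently):

```python
def sanitize(weights: dict) -> dict:
    """Fold the scalar lm_head gain into the lm_head weight so downstream
    quantization only ever sees an ordinary nn.Linear weight matrix."""
    gain = weights.pop("lm_head_gain.w_g", None)  # assumed key name
    if gain is not None:
        weights["lm_head.weight"] = weights["lm_head.weight"] * gain
    return weights
```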
No changes to shared `mlx-lm` infrastructure were required.

Tests
Generation
```
mlx_lm.generate --model lewtun/talkie-1930-13b-it-hf \
  --prompt '<|user|>What are some cutting-edge new technologies and inventions changing the world today?<|end|><|assistant|>' \
  --max-tokens 200 --temp 0
```

Cinema, wireless telegraphy, John o'Groats — period-appropriate content confirms the model is loaded and computing correctly.
Long-sequence (410-token greedy story, EOS-terminated)
```
mlx_lm.generate --model lewtun/talkie-1930-13b-it-hf \
  --prompt '<|user|>Write a long story about an inventor in the year 1925.<|end|><|assistant|>' \
  --max-tokens 1024 --temp 0
```

Output is a coherent 410-token story arc (invention → exhibition → factory → wholesale → fortune → parliament → death). No repetition or RoPE-degradation patterns.
Output dtype
Numerical comparison vs transformers (bf16, CPU, 94-token paragraph)
Comparison source (modified from the skill's `compare_predictions.py` to use bf16 transformers for memory budget): `compare_talkie.py`
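The attached script isn't reproduced here; a minimal sketch of that kind of comparison, with an illustrative prompt and no claim to match `compare_talkie.py` line for line, could look like:

```python
import numpy as np
import torch
import mlx.core as mx
from mlx_lm import load
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "lewtun/talkie-1930-13b-it-hf"
PROMPT = "<|user|>Describe the wireless telegraph.<|end|><|assistant|>"  # illustrative

# Reference next-token logits from transformers in bf16 (for memory budget)
hf_tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
)
ids = hf_tok(PROMPT, return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(ids).logits[0, -1].float()

# Same prompt through the MLX port
mlx_model, mlx_tok = load(MODEL, tokenizer_config={"trust_remote_code": True})
mlx_ids = mx.array(mlx_tok.encode(PROMPT))[None]
mlx_logits = np.array(mlx_model(mlx_ids)[0, -1].astype(mx.float32))

diff = (hf_logits - torch.from_numpy(mlx_logits)).abs()
print(f"max abs diff {diff.max():.3f}, mean abs diff {diff.mean():.4f}")
```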
Layer-by-layer (47-token prompt, abs/rel diffs vs HF bf16)
Mean abs diff grows linearly with depth (~200× from layer 0 to 39), consistent with bf16 noise accumulating on a residual stream whose magnitude grows over layers (no per-layer normalization on the residual itself). The 64.0 max at layer 39 is a single-position outlier; the final RMSNorm collapses it to 3.2 max, and the logits row sits at 2.0 max / 0.1 mean — within typical transformers/MLX bf16 disagreement.
Quantization
Bare q4 is unusually weak on Talkie. The combination of weightless norms, per-layer scalar gains, and the small 2048-token training horizon amplifies quant noise. The mixed-q4 recipe (`lm_head` = q8 + embed = bf16 + sensitivity-flagged blocks 14/37/38 = q8 + rest = q4) recovers short-form coherence; DWQ-calibrated q4 is the strongest 4-bit option (val loss 0.037 after 512 iterations at the default LR = 1e-6). A sketch of a custom `quant_predicate` for this recipe follows below.

`mlx_lm.awq` raises `NotImplementedError: AWQ support for talkie models NYI`. AWQ scales must be absorbed into the previous module's weight; talkie's weightless RMSNorms have no weight to absorb into. Adding AWQ would require either a custom `AWQConfig` matching talkie's structure or converting the weightless norms into learned norms initialized to ones.
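A hedged sketch of what such a predicate could look like for the mixed-q4 recipe above, assuming `mlx_lm.convert` exposes a `quant_predicate` callable taking `(path, module, config)` and returning a bool or a bits/group_size dict; module paths and bit-widths mirror the recipe described here, not the exact script used:

```python
from mlx_lm import convert
import mlx.nn as nn

SENSITIVE_BLOCKS = {14, 37, 38}  # sensitivity-flagged layers kept at q8

def talkie_mixed_q4(path: str, module: nn.Module, config: dict):
    # Path names are assumptions; talkie mirrors its reference attribute names.
    if not hasattr(module, "to_quantized"):
        return False                       # not a quantizable module
    if "embed_tokens" in path:
        return False                       # keep embeddings in bf16
    if "lm_head" in path:
        return {"bits": 8, "group_size": 64}
    if any(f"layers.{i}." in path for i in SENSITIVE_BLOCKS):
        return {"bits": 8, "group_size": 64}
    return {"bits": 4, "group_size": 64}   # everything else at q4

convert(
    "lewtun/talkie-1930-13b-it-hf",
    mlx_path="talkie-13b-mixed-q4",
    quantize=True,
    quant_predicate=talkie_mixed_q4,
)
```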
DWQ command

Note: an initial run with `--learning-rate 5e-5` (50× the default) diverged — final val loss 2.07 vs initial 0.25. The default LR is mandatory for talkie; the per-layer scalar gains amplify gradient updates in unusual ways.

Skill feedback (lessons captured back into the skill)

- torch's `F.rms_norm(x, (D,))` with default `eps=None` behaves as `eps≈0`, not `torch.finfo(dtype).eps`. Verified empirically.
- Models shipped as custom code (`auto_map`/`trust_remote_code`) require fetching `modeling_*.py` and `configuration_*.py` directly — they aren't in `transformers/models/`.
- Folding a scalar gain (`lm_head_gain`) into `lm_head.weight` at sanitize time is cleaner than tracking the scalar at runtime, and works correctly through quantization.
- Not every RoPE convention maps onto `mx.fast.rope`; check the `_apply_rotary_emb` source for the sign of sin before assuming the standard convention.
- The `mixed_*_*` quant recipes hardcode `down_proj`/`v_proj` patterns and won't match models with custom naming. Use a custom `quant_predicate` callable instead.
- AWQ needs norm weights to absorb scales into; weightless-norm models need either a custom `AWQConfig` plus restructuring or a different calibration method.
F.rms_norm(x, (D,))with defaulteps=Nonebehaves aseps≈0, nottorch.finfo(dtype).eps. Verified empirically.auto_map/trust_remote_code) require fetchingmodeling_*.pyandconfiguration_*.pydirectly — they aren't intransformers/models/.lm_head_gain) intolm_head.weightat sanitize time is cleaner than tracking the scalar at runtime, and works correctly through quantization.mx.fast.rope; check_apply_rotary_embsource for sign of sin before assuming the standard convention.mixed_*_*quant recipes hardcodedown_proj/v_projpatterns and won't match models with custom naming. Use a customquant_predicatecallable instead.AWQConfigplus restructure or a different calibration method.