
Qwen (QK Norm) Speedruns #1572

@Calvin-Xu

Description


We would like to observe the effect of QK Normalization (https://arxiv.org/abs/2010.04245) by conducting speedruns of Qwen 3 mirroring the Llama runs in the muon sweep. Since the main difference between Qwen3 and Llama here is QK norm (sliding window attention is not implemented yet)1, the Llama runs there can serve as controls without separate ablations.

Update 20250910

After talking to Helw150, we decided to do a speedrun sweep like #1405, except using the Qwen 3 configs. This will be run on the us-east5-a cluster.

Hypothesis or Goal

Hypothesis: QK norm increases compute efficiency (lower C4-EN BPB for the same compute budget in pairwise comparisons with the speedruns in #1405).
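For concreteness, C4-EN BPB here is bits-per-byte: mean cross-entropy in nats/token, converted to bits and normalized by the raw byte count of the eval set. A minimal sketch of the conversion (the function name and the example counts are illustrative, not from the speedrun harness):

```python
import math


def nats_per_token_to_bpb(loss_nats: float, tokens: int, num_bytes: int) -> float:
    """Convert mean cross-entropy (nats/token) to bits per byte.

    `tokens` and `num_bytes` are the token and raw-byte counts of the
    eval set; the values used below are illustrative assumptions.
    """
    total_bits = loss_nats * tokens / math.log(2)  # nats -> bits over the whole set
    return total_bits / num_bytes


# e.g. a loss of 2.77 nats/token on a set with ~4.3 bytes per token:
bpb = nats_per_token_to_bpb(2.77, tokens=1_000_000, num_bytes=4_300_000)
```

Because BPB is byte-normalized, it is comparable across tokenizers, which is why it is the metric for the pairwise comparisons with #1405.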

Links

Results

We find that Qwen 3 with QK Norm achieves slightly lower BPB for the same number of steps, but at the cost of longer training time and therefore more FLOPs; the overhead is substantial for smaller runs.


Draft (outdated)

Run speedruns of Qwen 3 (tentatively 0.6B) with and without QK Norm enabled

QK norm is implemented in Levanter's Attention layer, but I don't think we can turn off just QK norm from the Qwen3 config. Here is what I propose to do without Levanter changes:

# assumed import paths; adjust to match your Levanter version
import dataclasses

from levanter.models.qwen import Qwen3Config
from levanter.models.rotary import DefaultRotaryEmbeddingsConfig

# match https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
qwen3_06b = Qwen3Config(
    seq_len=4096,  # speedrun-friendly; HF uses 40960
    hidden_dim=1024,
    intermediate_dim=3072,
    num_layers=28,
    num_heads=16,
    num_kv_heads=8,
    head_dim=128,  # override; otherwise defaults to 1024 // 16 = 64
    rope=DefaultRotaryEmbeddingsConfig(theta=1_000_000.0),
    tie_word_embeddings=True,
    layer_norm_epsilon=1e-6,
)


@dataclasses.dataclass(frozen=True)
class Qwen3NoQKNormConfig(Qwen3Config):
    def attention_config(self):
        # build the attention config the base class produces, then drop QK norm
        cfg = super().attention_config()
        return dataclasses.replace(cfg, qk_norm=None)


# note: dataclasses.asdict would recursively convert nested dataclass fields
# (e.g. the rope config) into plain dicts, so copy fields shallowly instead
qwen3_06b_no_qk_norm = Qwen3NoQKNormConfig(
    **{f.name: getattr(qwen3_06b, f.name) for f in dataclasses.fields(qwen3_06b)}
)
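The subclass-override pattern above can be checked without Levanter installed. A self-contained toy with stand-in config classes (all names here are illustrative, not Levanter's):

```python
import dataclasses
from typing import Optional


# Toy stand-ins for the Levanter configs; field names are illustrative only.
@dataclasses.dataclass(frozen=True)
class AttnConfig:
    num_heads: int
    qk_norm: Optional[str] = "rms_norm"


@dataclasses.dataclass(frozen=True)
class ModelConfig:
    num_heads: int = 16

    def attention_config(self) -> AttnConfig:
        return AttnConfig(num_heads=self.num_heads)


@dataclasses.dataclass(frozen=True)
class NoQKNormConfig(ModelConfig):
    def attention_config(self) -> AttnConfig:
        # build the parent's attention config, then null out qk_norm
        return dataclasses.replace(super().attention_config(), qk_norm=None)


base = ModelConfig()
# shallow field copy, mirroring the proposal above
ablated = NoQKNormConfig(
    **{f.name: getattr(base, f.name) for f in dataclasses.fields(base)}
)
assert base.attention_config().qk_norm == "rms_norm"
assert ablated.attention_config().qk_norm is None
```

Everything except `attention_config` (hyperparameters, checkpoint layout) stays identical between the two configs, so the ablation isolates QK norm.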

I propose to use v4-128. Run info estimates:

The rough estimated compute (calculated as total model FLOPs / assumed MFU) for your run is probably between:
 * 2.72e+19 FLOPs assuming an MFU of 0.5, and
 * 6.81e+19 FLOPs assuming an MFU of 0.2.

This is calculated based on assumed MFU values and can be used as a rough estimate to guide your config/training setup.

Hardware and Model FLOPS Information:
Number of devices: 64
Device FLOPs: 2.75e+14 FLOP/s
Total peak hardware FLOPs: 1.76e+16 FLOP/s
Model FLOPs: 1.36e+19 FLOP
Model size: 615.05 million parameters
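The arithmetic behind these numbers (hardware compute = model FLOPs / assumed MFU; wall-clock time = model FLOPs / achieved throughput) can be reproduced directly from the figures above; the variable names are my own:

```python
# Figures from the run-info estimate above
model_flops = 1.36e19    # total model FLOPs for the run
num_devices = 64         # v4-128 slice -> 64 chips
device_flops = 2.75e14   # peak FLOP/s per TPU v4 chip

peak = num_devices * device_flops  # total peak hardware throughput, FLOP/s

for mfu in (0.5, 0.2):
    # hardware FLOPs spent, counting idle cycles implied by the MFU
    compute = model_flops / mfu
    # wall clock at this MFU: useful FLOPs / achieved throughput
    hours = model_flops / (peak * mfu) / 3600
    print(f"MFU {mfu:.1f}: {compute:.2e} FLOPs, ~{hours:.2f} h wall clock")
```

At MFU 0.5 this reproduces the 2.72e+19 FLOPs figure, and at MFU 0.2 roughly the 6.81e+19 figure (the small discrepancy is rounding in the quoted model FLOPs).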

Footnotes

  1. check https://github.com/stanford-crfm/levanter/blob/974733b779f8b1014d1457adfb6ec91316c8cc3a/src/levanter/models/qwen.py#L377
