Description
We would like to observe the effect of QK Normalization (https://arxiv.org/abs/2010.04245) by conducting speedruns of Qwen 3 mirroring the Llama runs in the muon sweep. Since the main difference between Qwen 3 and Llama is the QK norm (sliding window attention is not implemented yet)^1, the Llama runs there can serve as controls without separate ablations.
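For reference, QK norm applies a per-head RMSNorm to the query and key vectors before the attention dot product, which bounds the attention logits. A minimal NumPy sketch of the idea (learned scale parameters omitted; function names are mine, not Levanter's):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm over the head dimension (learned scale omitted for brevity)
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)

def attention_scores(q, k, head_dim, qk_norm=True):
    # With QK norm, queries and keys are normalized per head before the
    # dot product, so each logit is bounded by sqrt(head_dim)
    if qk_norm:
        q, k = rms_norm(q), rms_norm(k)
    return q @ k.T / np.sqrt(head_dim)

q = np.random.randn(4, 128)  # (seq, head_dim) for a single head
k = np.random.randn(4, 128)
scores = attention_scores(q, k, head_dim=128)
```

The bounded logits are what is hypothesized to stabilize training, at the cost of two extra norms per attention layer.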
Update 20250910
After talking to Helw150, we decided to do a speedrun sweep like #1405, except using the Qwen 3 configs. This will be run on the us-east5-a cluster.
Hypothesis or Goal
Hypothesis: QK norm increases compute efficiency (lower C4-EN BPB for the same compute budget in pairwise comparisons with the speedruns in #1405).
Links
Results
We find that Qwen 3 with QK norm reached slightly lower BPB at the same number of steps, but took longer to train and therefore used more FLOPs; the overhead is substantial for smaller runs.
Draft (outdated)
Run speedruns of Qwen 3 (tentatively 0.6B) with and without QK Norm enabled
QK norm is implemented in Levanter's Attention layer, but I don't think we can turn off just QK norm via config. Here is what I propose to do without Levanter changes:
    import dataclasses

    # import paths may differ by Levanter version
    from levanter.models.qwen import Qwen3Config
    from levanter.models.rotary import DefaultRotaryEmbeddingsConfig

    # match https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
    qwen3_06b = Qwen3Config(
        seq_len=4096,  # speedrun-friendly; HF uses 40960
        hidden_dim=1024,
        intermediate_dim=3072,
        num_layers=28,
        num_heads=16,
        num_kv_heads=8,
        head_dim=128,  # override; otherwise defaults to hidden_dim // num_heads = 64
        rope=DefaultRotaryEmbeddingsConfig(theta=1_000_000.0),
        tie_word_embeddings=True,
        layer_norm_epsilon=1e-6,
    )

    @dataclasses.dataclass(frozen=True)
    class Qwen3NoQKNormConfig(Qwen3Config):
        def attention_config(self):
            cfg = super().attention_config()
            return dataclasses.replace(cfg, qk_norm=None)

    # shallow field copy: dataclasses.asdict would recursively convert
    # nested configs (e.g. rope) into plain dicts
    qwen3_06b_no_qk_norm = Qwen3NoQKNormConfig(
        **{f.name: getattr(qwen3_06b, f.name) for f in dataclasses.fields(qwen3_06b)}
    )
I propose to use v4-128. Run info estimates:
The rough estimated compute (calculated as (total model FLOPs / Assumed MFU)) for your run is probably between:
* 2.72e+19 FLOPs assuming an MFU of 0.5, and
* 6.81e+19 FLOPs assuming an MFU of 0.2.
This is calculated based on assumed MFU values and can be used as a rough estimate to guide your config/training setup.
Hardware and Model FLOPS Information:
Number of devices: 64
Device FLOPs: 2.75e+14 FLOP/s
Total peak hardware FLOPs: 1.76e+16 FLOP/s
Model FLOPs: 1.36e+19 FLOP
Model size: 615.05 million parameters
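The estimate above is just model FLOPs divided by the assumed MFU. Under the standard 6ND approximation (training FLOPs ≈ 6 × params × tokens) it also implies a rough token budget. A sketch of the arithmetic (helper name is mine, not from the speedrun tool):

```python
def estimated_compute(model_flops: float, mfu: float) -> float:
    # hardware FLOPs spent = model FLOPs / achieved MFU
    return model_flops / mfu

model_flops = 1.36e19  # from the run info above
n_params = 615.05e6    # model size above

lo = estimated_compute(model_flops, 0.5)  # ~2.72e19 FLOPs at MFU 0.5
hi = estimated_compute(model_flops, 0.2)  # ~6.8e19 FLOPs at MFU 0.2

# 6ND approximation: implied token budget is roughly 3.7B tokens
tokens = model_flops / (6 * n_params)
```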
Footnotes
^1 Check https://github.com/stanford-crfm/levanter/blob/974733b779f8b1014d1457adfb6ec91316c8cc3a/src/levanter/models/qwen.py#L377