Skip to content

Commit 2770562

Browse files
authored
Fix RoPE with Yarn (#85)
1 parent 3871dde commit 2770562

File tree

4 files changed

+156
-37
lines changed

4 files changed

+156
-37
lines changed

evaluation/evaluate.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from pathlib import Path
7-
from typing import Optional
7+
from typing import Optional, Any
88

99
import torch
1010
from datasets import load_dataset
@@ -103,6 +103,8 @@ def evaluate(
103103
max_context_length: Optional[int] = None,
104104
compress_questions: bool = False,
105105
key_channel_compression_ratio: float = 0.5,
106+
rope_scaling: Optional[dict] = None,
107+
max_position_embeddings: Optional[int] = None,
106108
):
107109
"""
108110
Evaluate a model on a dataset using a press and save the results
@@ -131,6 +133,14 @@ def evaluate(
131133
Whether to compress the questions as well, by default False
132134
key_channel_compression_ratio : float, optional
133135
Key channel compression ratio for the channel press, by default 0.5
136+
rope_scaling : dict, optional
137+
RoPE-scaling configuration dictionary passed to
138+
model config's ``rope_scaling`` field.
139+
(e.g. {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768});
140+
by default None. If set, you **must** also provide ``max_position_embeddings``.
141+
max_position_embeddings : int, optional
142+
The value to set for ``max_position_embeddings`` in the model config when ``rope_scaling`` is used.
143+
Required if ``rope_scaling`` is not ``None``; ignored otherwise.
134144
"""
135145

136146
assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
@@ -184,7 +194,7 @@ def evaluate(
184194
press.compression_ratio = compression_ratio # type:ignore[attr-defined]
185195

186196
# Initialize pipeline with the correct attention implementation
187-
model_kwargs = {"torch_dtype": "auto"}
197+
model_kwargs: dict[str, Any] = {"torch_dtype": "auto"}
188198
if isinstance(press, ObservedAttentionPress):
189199
model_kwargs["attn_implementation"] = "eager"
190200
else:
@@ -194,6 +204,16 @@ def evaluate(
194204
model_kwargs["attn_implementation"] = "flash_attention_2"
195205
except ImportError:
196206
pass
207+
if rope_scaling is not None:
208+
if max_position_embeddings is None:
209+
raise ValueError("max_position_embeddings must be given when rope_scaling is used")
210+
211+
model_kwargs.update(
212+
{
213+
"max_position_embeddings": max_position_embeddings,
214+
"rope_scaling": rope_scaling,
215+
}
216+
)
197217

198218
if device == "auto":
199219
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)

kvpress/presses/finch_press.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
import torch
99
from torch.nn import functional as F
10-
from transformers.models.llama.modeling_llama import rotate_half
1110

1211
from kvpress.presses.base_press import BasePress
1312
from kvpress.presses.snapkv_press import SnapKVPress
13+
from kvpress.presses.key_rerotation_press import KeyRerotationPress
1414

1515

1616
@dataclass
@@ -93,18 +93,12 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
9393
chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
9494
indices.append(chunk_indices)
9595
indices = torch.cat(indices, dim=-1)
96-
97-
indices = torch.sort(indices, dim=2).values
98-
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
99-
100-
# Rerotate keys
10196
if self.rerotate_keys:
102-
cos, sin = kwargs["position_embeddings"]
103-
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
104-
keys = keys.gather(2, indices).contiguous()
105-
cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
106-
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
97+
indices = torch.sort(indices, dim=2).values
98+
keys = KeyRerotationPress.rerotate_keys(module, indices, keys)
99+
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
107100
else:
101+
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
108102
keys = keys.gather(2, indices).contiguous()
109103

110104
values = values.gather(2, indices).contiguous()

kvpress/presses/key_rerotation_press.py

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,85 @@ class KeyRerotationPress(BasePress):
3030
def __post_init__(self):
3131
assert isinstance(self.press, ScorerPress)
3232

33+
@staticmethod
34+
def _rerotate_cos_sin(x, inv_freq, selected_positions):
35+
"""
36+
Compute cosine and sine rotary positional embeddings required to
37+
re-rotate pruned keys back into the canonical RoPE space.
38+
39+
Parameters
40+
----------
41+
x : torch.Tensor
42+
Any key-like tensor that provides ``dtype`` and ``device``.
43+
Shape ``(bsz, num_key_value_heads, q_len, d)``.
44+
inv_freq : torch.Tensor
45+
``module.rotary_emb.inv_freq``. Shape ``(d//2,)``.
46+
selected_positions : torch.Tensor
47+
Indices of the *kept* tokens.
48+
Shape ``(bsz, num_key_value_heads, n_kept)``.
49+
50+
Returns
51+
-------
52+
cos, sin : torch.Tensor
53+
Cosine and sine embeddings, each of shape
54+
``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``.
55+
"""
56+
bsz, num_key_value_heads, n_kept = selected_positions.shape
57+
device = selected_positions.device
58+
device_type = x.device.type
59+
dtype = x.dtype
60+
# Original positional indices
61+
idx = torch.arange(0, n_kept, device=device) # (n_kept,)
62+
idx = idx.unsqueeze(0) # (1, n_kept)
63+
inv_freq = inv_freq[None, None, :, None].float().expand(bsz, num_key_value_heads, -1, 1)
64+
idx = idx[:, None, :].float().expand(bsz, num_key_value_heads, n_kept)
65+
# Compute delta between original and selected positions
66+
delta_pos = idx - selected_positions # (bsz, num_key_value_heads, n_kept)
67+
delta_pos = delta_pos.unsqueeze(2) # (bsz, num_key_value_heads, 1, n_kept)
68+
69+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
70+
71+
with torch.autocast(device_type=device_type, enabled=False):
72+
# Compute the new freq by scaling inv_freq by delta
73+
freqs = delta_pos.float() * inv_freq.float() # (bsz, num_key_value_heads, d//2, n_kept)
74+
freqs = freqs.transpose(2, 3) # (bsz, num_key_value_heads, n_kept, d//2)
75+
emb = torch.cat((freqs, freqs), dim=-1)
76+
# Compute cosine and sine required to re-rotate keys to selected positions
77+
cos = emb.cos().contiguous()
78+
sin = emb.sin().contiguous()
79+
return cos.to(dtype=dtype), sin.to(dtype=dtype)
80+
81+
@staticmethod
82+
def rerotate_keys(
83+
module: nn.Module,
84+
indices: torch.Tensor,
85+
keys: torch.Tensor,
86+
) -> torch.Tensor:
87+
"""
88+
Rerotate keys to have a uniform RoPE representation of keys after pruning.
89+
90+
Parameters
91+
----------
92+
module : nn.Module
93+
The model module containing the rotary embedding.
94+
indices : torch.Tensor
95+
Indices of the kept tokens after pruning.
96+
keys : torch.Tensor
97+
The keys tensor to be rerotated.
98+
99+
Returns
100+
-------
101+
torch.Tensor
102+
The rerotated keys tensor of shape
103+
``(bsz, num_heads, n_kept, d)``.
104+
"""
105+
new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(keys,
106+
module.rotary_emb.inv_freq,
107+
indices)
108+
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
109+
keys = keys.gather(2, indices).contiguous()
110+
return (keys * new_cos) + (rotate_half(keys) * new_sin)
111+
33112
def compress(
34113
self,
35114
module: nn.Module,
@@ -50,22 +129,7 @@ def compress(
50129
n_kept = int(q_len * (1 - self.press.compression_ratio))
51130
indices = scores.topk(n_kept, dim=-1).indices
52131
indices = torch.sort(indices, dim=2).values
132+
keys = self.rerotate_keys(module, indices, keys)
53133
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
54-
55-
cos, sin = kwargs["position_embeddings"]
56-
# Rerotate as follows
57-
# 1. keys = RoPE(W_k * hidden_states)
58-
# 2. keys_unrotated = RoPE^-1(keys)
59-
# 3. keys_pruned = prune(keys_unrotated)
60-
# 4. keys = RoPE(keys_pruned)
61-
62-
# 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
63-
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
64-
# 3. Prune keys
65-
keys = keys.gather(2, indices).contiguous()
66-
# 4. Apply RoPE
67-
cos, sin = cos[:, :n_kept], sin[:, :n_kept]
68-
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
69-
70134
values = values.gather(2, indices).contiguous()
71135
return keys, values

tests/presses/test_key_rerotation_press_rope.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,73 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
3+
#
4+
# Extended to test both the *default* and the *YaRN-scaled* rotary-embedding
5+
# variants with the smallest possible code changes.
46

57
import inspect
68
from dataclasses import dataclass
9+
from copy import deepcopy
710

811
import pytest
912
import torch
1013
from torch import nn
11-
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half
14+
from transformers.models.llama.modeling_llama import (
15+
LlamaAttention,
16+
LlamaForCausalLM,
17+
LlamaRotaryEmbedding,
18+
rotate_half,
19+
)
20+
from transformers import Gemma3ForCausalLM
1221

1322
from kvpress import KeyRerotationPress, ScorerPress
1423
from tests.fixtures import unit_test_model # noqa: F401
1524

1625

26+
@pytest.mark.parametrize("rope_variant", ["default", "yarn"])
1727
@pytest.mark.parametrize("precision", ["full", "half"])
18-
def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision): # noqa: F811
28+
def test_rerotate_keys_is_matches_reference_implementation(
29+
unit_test_model: LlamaForCausalLM, # noqa: F811
30+
rope_variant,
31+
precision,
32+
):
1933
"""
20-
Compare KeyRerotationPress' rerotation of keys with the reference implementation.
21-
In the reference implementation, we are computing
34+
Compare KeyRerotationPress' rerotation of keys with the reference
35+
implementation.
36+
37+
Reference path:
2238
1. keys = W_k * hidden_states
2339
2. keys_pruned = prune(keys)
2440
3. keys = RoPE(keys_pruned)
41+
42+
Press path:
43+
1. keys = W_k * hidden_states
44+
2. keys = RoPE(keys)
45+
3. keys_pruned = KeyRerotationPress.rerotate_keys(...)
2546
"""
47+
if rope_variant == "yarn":
48+
layer0 = unit_test_model.model.layers[0]
49+
cfg = deepcopy(layer0.self_attn.config)
50+
cfg.rope_scaling = {
51+
"factor": 4.0,
52+
"original_max_position_embeddings": 32768,
53+
"rope_type": "yarn",
54+
}
55+
cfg.max_position_embeddings = 131072
56+
try:
57+
unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
58+
except KeyError:
59+
pytest.skip("YaRN rotary-embedding not available in this transformers version.")
60+
61+
for layer in unit_test_model.model.layers:
62+
if isinstance(unit_test_model, Gemma3ForCausalLM) and layer.is_sliding:
63+
# Skip layers with sliding window attention, only for Gemma3
64+
continue
65+
layer.self_attn.rotary_emb = unit_test_model.model.rotary_emb
66+
2667
if precision == "half" and torch.cuda.is_available():
2768
unit_test_model = unit_test_model.cuda().half()
28-
elif precision == "half" and not torch.cuda.is_available():
29-
pytest.skip("Half precision test is skipped because CUDA is not available.")
69+
elif precision == "half":
70+
pytest.skip("Half-precision test skipped because CUDA is not available.")
3071

3172
original_press = RandomPressStoreIndices(compression_ratio=0.5)
3273
key_rerotation_press = KeyRerotationPress(press=original_press)
@@ -47,7 +88,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam
4788
keys,
4889
values,
4990
attentions=None,
50-
kwargs={"position_embeddings": get_rope_embeddings(module, keys)},
91+
kwargs={},
5192
)
5293

5394
indices = original_press.indices

0 commit comments

Comments
 (0)