Skip to content

Commit 0dafec8

Browse files
SimJeg and maxjeblick
authored and committed
Add DuoAttention on the fly (#63)
Signed-off-by: Max Jeblick <[email protected]>
1 parent 6c90873 commit 0dafec8

File tree

4 files changed

+87
-4
lines changed

4 files changed

+87
-4
lines changed

evaluation/evaluate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
StreamingLLMPress,
3535
ThinKPress,
3636
TOVAPress,
37+
QFilterPress,
3738
)
3839

3940
logger = logging.getLogger(__name__)
@@ -75,9 +76,10 @@
7576
"think": ThinKPress(),
7677
"tova": TOVAPress(),
7778
"duo_attention": DuoAttentionPress(),
79+
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
7880
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
81+
"qfilter": QFilterPress(),
7982
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
80-
"full_kv": ExpectedAttentionPress(0.0),
8183
}
8284

8385

kvpress/presses/duo_attention_press.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from cachetools import cached, LRUCache # type: ignore[import-untyped]
45
from contextlib import contextmanager
56
from dataclasses import dataclass, field
67
from io import StringIO
78

89
import numpy as np
910
import requests # type: ignore[import-untyped]
1011
import torch
12+
from datasets import load_dataset
13+
from transformers import AutoTokenizer
14+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
1115

1216
from kvpress.presses.base_press import BasePress
1317

@@ -20,6 +24,8 @@
2024
"mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
2125
}
2226

27+
cache = LRUCache(maxsize=128)
28+
2329

2430
@dataclass
2531
class DuoAttentionPress(BasePress):
@@ -29,12 +35,15 @@ class DuoAttentionPress(BasePress):
2935
Splits attention heads into two types:
3036
- Retrieval heads: use the full KV cache
3137
- Streaming heads: use only sink and recent tokens.
32-
33-
Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/
3438
The higher the head_compression_ratio, the more streaming heads are used.
39+
40+
Head classification is based on scores.
41+
- If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/
42+
- (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly
3543
"""
3644

3745
head_compression_ratio: float = 0.0
46+
on_the_fly_scoring: bool = False
3847
compression_ratio_: float = field(init=False, default=None)
3948
recent_size: int = field(init=False, default=None)
4049
sink_size: int = field(init=False, default=None)
@@ -45,7 +54,10 @@ def __post_init_from_model__(self, model):
4554
Initialize sink_size, recent_size, and streaming_mask from a model
4655
"""
4756
# Load attention pattern from the DuoAttention repo
48-
self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)
57+
if self.on_the_fly_scoring:
58+
self.sink_size, self.recent_size, head_scores = 128, 256, duo_attention_on_the_fly(model)
59+
else:
60+
self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)
4961

5062
# Define retrieval and streaming heads through a binary mask
5163
n_pruned = round(head_scores.size * self.head_compression_ratio)
@@ -82,6 +94,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
8294
return keys, values
8395

8496
@staticmethod
97+
@cached(cache, key=lambda model: model.config.name_or_path)
8598
def load_attention_pattern(model):
8699
"""
87100
Load the attention pattern from the DuoAttention repo
@@ -108,3 +121,68 @@ def __call__(self, model):
108121
self.__post_init_from_model__(model)
109122
with super().__call__(model):
110123
yield
124+
125+
126+
@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
    """
    New experimental method to quickly compute DuoAttention scores:
    - Compute the mean query and key on num_samples random samples from BookSum
    - Repeat the mean query and key q_len times and apply RoPE to get (Q, K)
    - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve"
    These scores could also be saved to avoid recomputing them but this method is still experimental

    Parameters
    ----------
    model:
        A HuggingFace causal LM. The code reads ``model.config`` (num_attention_heads,
        num_key_value_heads, num_hidden_layers, name_or_path), ``model.model.layers[i]``
        and ``model.model.rotary_emb`` — i.e. it assumes a Llama-style architecture.
    num_samples : int
        Number of BookSum chapters averaged over (sampled with a fixed seed, so the
        result is deterministic for a given model).
    q_len : int
        Length of the synthetic repeated-query/key sequence used for scoring.

    Returns
    -------
    numpy.ndarray
        Scores of shape (num_hidden_layers, num_key_value_heads). Results are memoized
        in the module-level LRU cache, keyed on (model name, num_samples, q_len) —
        two distinct model objects with the same name_or_path share one cache entry.
    """

    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    # GQA: how many query heads share each key/value head
    num_key_value_groups = num_heads // num_key_value_heads

    # Load data (fixed random_state makes the sample — and thus the scores — reproducible)
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables
    position_ids = torch.arange(q_len).unsqueeze(0)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores
    for text in texts:
        with torch.no_grad():
            # Compute hidden states; hidden_states[:-1] drops the final-norm output so
            # hidden_states[i] is the input to layer i
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

            for layer_idx, h in enumerate(hidden_states):
                module = model.model.layers[layer_idx]
                d = module.self_attn.head_dim
                # Re-apply the layer's input norm since hidden_states are pre-norm
                h = module.input_layernorm(h)

                # Mean query: average q_proj over the sequence, then tile to length q_len
                # -> shape (1, num_heads, q_len, d) after transpose
                q = module.self_attn.q_proj(h)
                q = q.view(1, q.shape[1], -1, d)
                q = q.mean(dim=1, keepdim=True)
                q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Mean key: same construction -> (1, num_key_value_heads, q_len, d)
                k = module.self_attn.k_proj(h)
                k = k.view(1, k.shape[1], -1, d)
                k = k.mean(dim=1, keepdim=True)
                k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Apply RoPE so positions 0..q_len-1 are distinguishable despite the
                # repeated (mean) content, then expand keys to one per query head
                cos, sin = model.model.rotary_emb(h, position_ids.to(h.device))
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
                k = k.repeat_interleave(num_key_value_groups, dim=1)

                # Compute attention weights for the last token only: (num_heads, q_len)
                attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
                attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze()

                # Compute score: area under the cumulated attention curve
                # (high = attention mass concentrated on early/recent positions)
                s = torch.cumsum(attn_weights, dim=1).mean(1)
                # Average query heads within each key/value group -> one score per KV head
                s = s.view(-1, num_key_value_groups).mean(1)

                # Store the scores (running mean over the num_samples texts)
                scores[layer_idx] += s.cpu() / num_samples
    return scores.numpy()

kvpress/presses/qfilter_press.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from functools import cache
45
from contextlib import contextmanager
56
from dataclasses import dataclass
67

@@ -32,6 +33,7 @@ def __post_init_from_model__(self, model):
3233
self.q_filters = self.q_filters.to(model.dtype)
3334

3435
@staticmethod
36+
@cache
3537
def load_q_filters(model_name):
3638
try:
3739
return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ rouge = "^1.0.1"
2626
bert-score = "^0.3.13"
2727
accelerate = "^1.0.0"
2828
requests = "^2.32.3"
29+
cachetools = "^5.5.2"
2930

3031
[tool.poetry.dev-dependencies]
3132
pytest = "^7.0.0"

0 commit comments

Comments
 (0)