Skip to content

Commit 5984ad2

Browse files
committed
Add DuoAttention on the fly
Signed-off-by: SimJeg <sjegou@nvidia.com>
1 parent 4100647 commit 5984ad2

File tree

3 files changed

+96
-14
lines changed

3 files changed

+96
-14
lines changed

evaluation/evaluate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"think": ThinKPress(),
6666
"tova": TOVAPress(),
6767
"duo_attention": DuoAttentionPress(),
68+
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
6869
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
6970
}
7071

kvpress/presses/duo_attention_press.py

Lines changed: 90 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from time import time
45
from contextlib import contextmanager
56
from dataclasses import dataclass, field
67
from io import StringIO
78

9+
810
import numpy as np
911
import requests # type: ignore[import-untyped]
1012
import torch
13+
from datasets import load_dataset
14+
from transformers import AutoTokenizer
15+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
1116

1217
from kvpress.presses.base_press import BasePress
1318

@@ -29,12 +34,15 @@ class DuoAttentionPress(BasePress):
2934
Splits attention heads into two types:
3035
- Retrieval heads: use the full KV cache
3136
- Streaming heads: use only sink and recent tokens.
32-
33-
Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/
3437
The higher the head_compression_ratio, the more streaming heads are used.
38+
39+
Head classification is based on scores.
40+
- If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/
41+
- (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly
3542
"""
3643

3744
head_compression_ratio: float = 0.0
45+
on_the_fly_scoring: bool = False
3846
compression_ratio_: float = field(init=False, default=None)
3947
recent_size: int = field(init=False, default=None)
4048
sink_size: int = field(init=False, default=None)
def __post_init_from_model__(self, model):
    """
    Initialize sink_size, recent_size, and streaming_mask from a model.

    Results are cached on the model name so repeated calls with the same
    model skip the (potentially expensive) score computation.
    """
    model_name = model.config.name_or_path
    if getattr(self, "_post_init_model_name", None) == model_name:
        return  # already initialized for this model

    # Load attention pattern (sink/recent sizes and head scores) from the DuoAttention repo
    self.sink_size, self.recent_size, scores = self.load_attention_pattern(model)
    if self.on_the_fly_scoring:
        # Experimental: recompute the head scores directly from the model
        scores = duo_attention_on_the_fly(model)

    # Mark the lowest-scoring heads as streaming heads via a binary mask
    num_streaming = round(scores.size * self.head_compression_ratio)
    self.streaming_mask = torch.zeros(scores.shape, dtype=bool, device=model.device)
    if num_streaming > 0:
        flat_indices = np.argsort(scores, axis=None)[:num_streaming]
        self.streaming_mask[np.unravel_index(flat_indices, scores.shape)] = True
    self._post_init_model_name = model_name
5668

5769
@property
5870
def compression_ratio(self) -> float:
@@ -108,3 +120,70 @@ def __call__(self, model):
108120
self.__post_init_from_model__(model)
109121
with super().__call__(model):
110122
yield
123+
124+
125+
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
    """
    New experimental method to quickly compute DuoAttention scores:
    - Compute the mean query and key on num_samples random samples from BookSum
    - Repeat the mean query and key q_len times and apply RoPE to get (Q, K)
    - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve"
    These scores could also be saved to avoid recomputing them but this method is still experimental

    Parameters
    ----------
    model:
        A transformers causal LM. Assumes Llama-style internals
        (`model.model.layers[i].self_attn.{q_proj,k_proj,head_dim}`,
        `model.model.rotary_emb`) — TODO confirm for other architectures.
    num_samples: int
        Number of BookSum chapters averaged over.
    q_len: int
        Synthetic sequence length used for the RoPE-rotated mean query/key.

    Returns
    -------
    np.ndarray of shape (num_hidden_layers, num_key_value_heads), one score per KV head.
    """

    start = time()
    print(f"Starting computation of DuoAttention scores based on {num_samples} samples")
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    # Number of query heads sharing each KV head (grouped-query attention)
    num_key_value_groups = num_heads // num_key_value_heads

    # Load data (fixed seed so the sampled chapters — and hence the scores — are reproducible)
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables
    position_ids = torch.arange(q_len).unsqueeze(0)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores
    for text in texts:
        with torch.no_grad():
            # Compute hidden states; [:-1] assumes HF returns (embeddings, layer_1_out, ...,
            # layer_N_out) so index i is the INPUT to layer i — TODO confirm
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

            for layer_idx, h in enumerate(hidden_states):
                module = model.model.layers[layer_idx]
                d = module.self_attn.head_dim
                # Re-apply the layer's input norm, as the attention module would see it
                h = module.input_layernorm(h)

                # Mean query: average over the sequence, tile to q_len -> (1, num_heads, q_len, d)
                q = module.self_attn.q_proj(h)
                q = q.view(1, q.shape[1], -1, d)
                q = q.mean(dim=1, keepdim=True)
                q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Mean key: same construction -> (1, num_key_value_heads, q_len, d)
                k = module.self_attn.k_proj(h)
                k = k.view(1, k.shape[1], -1, d)
                k = k.mean(dim=1, keepdim=True)
                k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Apply RoPE so positions 0..q_len-1 get distinct rotations despite identical content
                cos, sin = model.model.rotary_emb(h, position_ids.to(h.device))
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
                # Expand KV heads to match the number of query heads (GQA)
                k = k.repeat_interleave(num_key_value_groups, dim=1)

                # Compute attention weights for the last token -> (num_heads, q_len) after squeeze
                attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
                attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze()

                # Compute score: area under the cumulated attention curve
                # (larger when attention mass sits on early positions)
                s = torch.cumsum(attn_weights, dim=1).mean(1).cpu()
                # Average the grouped query heads back down to one score per KV head
                s = s.view(-1, num_key_value_groups).mean(1)

                # Store the scores (running mean over the num_samples texts)
                scores[layer_idx] += s / num_samples
    print(f"Finished computation of DuoAttention scores in {time() - start:.2f}s")
    return scores.numpy()

kvpress/presses/qfilter_press.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ class QFilterPress(ScorerPress):
2727
"""
2828

2929
def __post_init_from_model__(self, model):
    """Load the Q-filters for this model once, caching on the model name."""
    full_name = model.config.name_or_path
    if getattr(self, "_post_init_model_name", None) == full_name:
        return  # filters already loaded for this exact model
    # The filter repository is keyed by the short model name (last path component)
    short_name = full_name.split("/")[-1]
    self.q_filters = self.load_q_filters(short_name).to(model.dtype)
    self._post_init_model_name = full_name
3335

3436
@staticmethod
3537
def load_q_filters(model_name):

0 commit comments

Comments
 (0)