Skip to content

Commit a94a78d

Browse files
authored
Add DuoAttentionPress (#50)
* Add DuoAttentionPress * Fix tests and compression_ratio * Address feedback * Update plot * Update version
1 parent f98de1f commit a94a78d

File tree

8 files changed

+145
-3
lines changed

8 files changed

+145
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
7070
Some presses rely on a different logic:
7171
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
7272
- `SimLayerKVPress` ([source](kvpress/presses/simlayerkv_press.py), [paper](https://arxiv.org/abs/2410.13846)): identify "lazy" layers, and apply the StreamingLLM approach to them
73+
- `DuoAttentionPress` ([source](kvpress/presses/duo_attention_press.py), [paper](https://arxiv.org/abs/2410.10819)): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
7374

7475
Finally we provide wrapper presses that can be combined with other presses:
7576
- `AdaKVPress` ([source](kvpress/presses/adakv_press.py), [paper](https://arxiv.org/abs/2407.11550)): prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions
1.91 KB
Loading

evaluation/evaluate.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
StreamingLLMPress,
2929
ThinKPress,
3030
TOVAPress,
31+
DuoAttentionPress,
3132
)
3233

3334
logger = logging.getLogger(__name__)
@@ -62,6 +63,7 @@
6263
"streaming_llm": StreamingLLMPress(),
6364
"think": ThinKPress(),
6465
"tova": TOVAPress(),
66+
"duo_attention": DuoAttentionPress(),
6567
}
6668

6769

@@ -139,7 +141,11 @@ def evaluate(
139141
# Load press
140142
assert press_name in PRESS_DICT
141143
press = PRESS_DICT[press_name]
142-
press.compression_ratio = compression_ratio # type:ignore[attr-defined]
144+
145+
if isinstance(press, (DuoAttentionPress)):
146+
press.head_compression_ratio = compression_ratio
147+
else:
148+
press.compression_ratio = compression_ratio # type:ignore[attr-defined]
143149

144150
# Initialize pipeline with the correct attention implementation
145151
model_kwargs = {"torch_dtype": "auto"}
@@ -176,16 +182,18 @@ def evaluate(
176182
max_context_length=max_context_length,
177183
)
178184
df.loc[df_.index, "predicted_answer"] = output["answers"]
185+
df.loc[df_.index, "compression_ratio"] = press.compression_ratio # type:ignore[attr-defined]
179186
torch.cuda.empty_cache()
180187

181188
# Save answers
182-
df["predicted_answer"].to_csv(str(save_filename), index=False)
189+
df[["predicted_answer", "compression_ratio"]].to_csv(str(save_filename), index=False)
183190

184191
# Calculate metrics
185192
scorer = SCORER_DICT[dataset]
186193
metrics = scorer(df)
187194
with open(str(save_filename).replace(".csv", ".json"), "w") as f:
188195
json.dump(metrics, f)
196+
print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
189197
print(metrics)
190198

191199

kvpress/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from kvpress.presses.think_press import ThinKPress
2222
from kvpress.presses.tova_press import TOVAPress
2323
from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
24+
from kvpress.presses.duo_attention_press import DuoAttentionPress
25+
2426
# Patch the attention functions to support head-wise compression
2527
patch_attention_functions()
2628

@@ -44,4 +46,5 @@
4446
"PerLayerCompressionPress",
4547
"KeyRerotationPress",
4648
"ChunkPress",
49+
"DuoAttentionPress",
4750
]
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from io import StringIO
5+
from dataclasses import dataclass, field
6+
from contextlib import contextmanager
7+
8+
import torch
9+
import requests # type: ignore[import-untyped]
10+
import numpy as np
11+
12+
from kvpress.presses.base_press import BasePress
13+
14+
15+
PATTERNS_DICT = {
16+
"togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
17+
"gradientai/Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
18+
"gradientai/Llama-3-8B-Instruct-Gradient-4194k": "Llama-3-8B-Instruct-Gradient-4194k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
19+
"meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct/lr=0.02-reg=0.05-ctx=1000_128000-multi_passkey10", # noqa: E501
20+
"mistralai/Mistral-7B-Instruct-v0.2": "Mistral-7B-Instruct-v0.2/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
21+
"mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
22+
}
23+
24+
25+
@dataclass
26+
class DuoAttentionPress(BasePress):
27+
"""
28+
Implements DuoAttention (https://arxiv.org/abs/2410.10819)
29+
30+
Splits attention heads into two types:
31+
- Retrieval heads: use the full KV cache
32+
- Streaming heads: use only sink and recent tokens.
33+
34+
Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/
35+
The higher the head_compression_ratio, the more streaming heads are used.
36+
"""
37+
38+
head_compression_ratio: float = 0.0
39+
compression_ratio_: float = field(init=False, default=None)
40+
recent_size: int = field(init=False, default=None)
41+
sink_size: int = field(init=False, default=None)
42+
streaming_mask: torch.Tensor = field(init=False, default=None)
43+
44+
def __post_init_from_model__(self, model):
45+
"""
46+
Initialize sink_size, recent_size, and streaming_mask from a model
47+
"""
48+
# Load attention pattern from the DuoAttention repo
49+
self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)
50+
51+
# Define retrieval and streaming heads through a binary mask
52+
n_pruned = round(head_scores.size * self.head_compression_ratio)
53+
self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device)
54+
if n_pruned > 0:
55+
indices = np.argsort(head_scores, axis=None)[:n_pruned]
56+
self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True
57+
58+
@property
59+
def compression_ratio(self) -> float:
60+
assert self.compression_ratio_ is not None, "Forward pass must be run to compute the compression ratio"
61+
return self.compression_ratio_
62+
63+
@compression_ratio.setter
64+
def compression_ratio(self, value):
65+
raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
66+
67+
def compress(self, module, hidden_states, keys, values, attentions, kwargs):
68+
69+
assert module.config._attn_implementation != "eager", "eager mode not supported"
70+
q_len = hidden_states.shape[1]
71+
72+
if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)):
73+
74+
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
75+
masked_keys = torch.zeros_like(keys[..., 0], dtype=torch.bool)
76+
masked_keys[:, self.streaming_mask[module.layer_idx], self.sink_size : -self.recent_size] = True
77+
module.masked_key_indices = torch.nonzero(masked_keys, as_tuple=True)
78+
79+
# Compute the compression ratio
80+
self.compression_ratio_ = self.streaming_mask.float().mean().item()
81+
self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / q_len
82+
83+
return keys, values
84+
85+
@staticmethod
86+
def load_attention_pattern(model):
87+
"""
88+
Load the attention pattern from the DuoAttention repo
89+
"""
90+
91+
assert (
92+
model.config.name_or_path in PATTERNS_DICT
93+
), f"Checkpoint {model.config.name_or_path} not in {list(PATTERNS_DICT.keys())}"
94+
base_url = "https://raw.githubusercontent.com/mit-han-lab/duo-attention/refs/heads/main/attn_patterns"
95+
url = f"{base_url}/{PATTERNS_DICT[model.config.name_or_path]}/"
96+
97+
# Load config
98+
config = requests.get(url + "config.json").json()
99+
100+
# Load head scores and clip as in duo_attn.utils.load_attn_pattern
101+
text = requests.get(url + "full_attention_heads.tsv").text
102+
head_scores = np.loadtxt(StringIO(text), dtype=float, delimiter="\t")
103+
head_scores = np.clip(head_scores, 0, 1)
104+
105+
return config["sink_size"], config["recent_size"], head_scores
106+
107+
@contextmanager
108+
def __call__(self, model):
109+
self.__post_init_from_model__(model)
110+
with super().__call__(model):
111+
yield

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "kvpress"
33
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
44
description = "Efficiently compress the KV cache of any pretrained transformer"
5-
version = "0.2.2"
5+
version = "0.2.3"
66
readme = "README.md"
77

88
[tool.poetry.dependencies]
@@ -25,6 +25,7 @@ pandas = "^2.2.2"
2525
rouge = "^1.0.1"
2626
bert-score = "^0.3.13"
2727
accelerate = "^1.0.0"
28+
requests = "^2.32.3"
2829

2930
[tool.poetry.dev-dependencies]
3031
pytest = "^7.0.0"

tests/default_presses.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import numpy as np
45

56
from kvpress import (
67
ExpectedAttentionPress,
@@ -11,11 +12,21 @@
1112
StreamingLLMPress,
1213
ThinKPress,
1314
TOVAPress,
15+
DuoAttentionPress,
1416
)
1517

18+
19+
class TestDuoAttentionPress(DuoAttentionPress):
20+
@staticmethod
21+
def load_attention_pattern(model):
22+
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_key_value_heads
23+
return 2, 2, np.random.rand(n_layers, n_heads)
24+
25+
1626
# contains all presses to be tested
1727
# kwargs should be ordered easy to hard compression
1828
default_presses = [
29+
{"cls": TestDuoAttentionPress, "kwargs": [{"head_compression_ratio": 0.2}, {"head_compression_ratio": 0.8}]},
1930
{"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
2031
{"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
2132
{"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from kvpress.presses.duo_attention_press import DuoAttentionPress, PATTERNS_DICT
2+
3+
4+
def test_load_attention_pattern():
5+
for model_name in PATTERNS_DICT:
6+
model = type("model", (), {"config": type("config", (), {"name_or_path": model_name})})()
7+
DuoAttentionPress.load_attention_pattern(model)

0 commit comments

Comments
 (0)