Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
QFilterPress,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,9 +76,10 @@
"think": ThinKPress(),
"tova": TOVAPress(),
"duo_attention": DuoAttentionPress(),
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
"qfilter": QFilterPress(),
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
"full_kv": ExpectedAttentionPress(0.0),
}


Expand Down
84 changes: 81 additions & 3 deletions kvpress/presses/duo_attention_press.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from cachetools import cached, LRUCache # type: ignore[import-untyped]
from contextlib import contextmanager
from dataclasses import dataclass, field
from io import StringIO

import numpy as np
import requests # type: ignore[import-untyped]
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from kvpress.presses.base_press import BasePress

Expand All @@ -20,6 +24,8 @@
"mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
}

cache = LRUCache(maxsize=128)


@dataclass
class DuoAttentionPress(BasePress):
Expand All @@ -29,12 +35,15 @@ class DuoAttentionPress(BasePress):
Splits attention heads into two types:
- Retrieval heads: use the full KV cache
- Streaming heads: use only sink and recent tokens.

Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/
The higher the head_compression_ratio, the more streaming heads are used.

Head classification is based on scores.
- If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/
- (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly
"""

head_compression_ratio: float = 0.0
on_the_fly_scoring: bool = False
compression_ratio_: float = field(init=False, default=None)
recent_size: int = field(init=False, default=None)
sink_size: int = field(init=False, default=None)
Expand All @@ -45,7 +54,10 @@ def __post_init_from_model__(self, model):
Initialize sink_size, recent_size, and streaming_mask from a model
"""
# Load attention pattern from the DuoAttention repo
self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)
if self.on_the_fly_scoring:
self.sink_size, self.recent_size, head_scores = 128, 256, duo_attention_on_the_fly(model)
else:
self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)

# Define retrieval and streaming heads through a binary mask
n_pruned = round(head_scores.size * self.head_compression_ratio)
Expand Down Expand Up @@ -82,6 +94,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
return keys, values

@staticmethod
@cached(cache, key=lambda model: model.config.name_or_path)
def load_attention_pattern(model):
"""
Load the attention pattern from the DuoAttention repo
Expand All @@ -108,3 +121,68 @@ def __call__(self, model):
self.__post_init_from_model__(model)
with super().__call__(model):
yield


@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
def duo_attention_on_the_fly(model, num_samples: int = 50, q_len: int = 500):
    """
    Experimental method to quickly compute DuoAttention head scores for a model:
    - For each of num_samples random BookSum chapters, run the model once and, per layer,
      compute the mean (over sequence positions) query and key vectors.
    - Repeat the mean query and key q_len times and apply RoPE so only position information
      differs across the q_len copies, yielding (Q, K).
    - Compute the attention weights of the last query position over all q_len key positions,
      then score each head by the "area under the cumulated attention curve"
      (mean of the cumulative sum of attention weights along the key axis).

    Results are memoized in the module-level LRU cache, keyed on
    (model.config.name_or_path, num_samples, q_len). These scores could also be saved to
    disk to avoid recomputing them, but this method is still experimental.

    Parameters
    ----------
    model:
        A HuggingFace causal LM (accessed via model.config, model.model.layers,
        model.model.rotary_emb) — assumes a Llama/Mistral-style architecture; TODO confirm
        for other model families.
    num_samples:
        Number of BookSum chapters averaged over.
    q_len:
        Number of positions the mean query/key are repeated to before applying RoPE.

    Returns
    -------
    np.ndarray of shape (num_hidden_layers, num_key_value_heads) with one score per
    KV head (head-group scores are averaged over the query heads in each group).
    """

    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    # GQA: number of query heads sharing each key/value head
    num_key_value_groups = num_heads // num_key_value_heads

    # Load data: fixed random_state makes the sample (and thus the cached scores) reproducible
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables
    position_ids = torch.arange(q_len).unsqueeze(0)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores
    for text in texts:
        with torch.no_grad():
            # Compute hidden states; note the text is tokenized without truncation — very long
            # chapters may exceed the model's context window (TODO confirm acceptable here)
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            # hidden_states[i] is the input to layer i (the last element, the final layer
            # output, is dropped)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

            for layer_idx, h in enumerate(hidden_states):
                module = model.model.layers[layer_idx]
                d = module.self_attn.head_dim
                # Re-apply the pre-attention norm so q_proj/k_proj see the same input as
                # they do inside the layer's attention block
                h = module.input_layernorm(h)

                # Mean query: average over sequence positions, then tile to q_len copies.
                # Shape after transpose: (1, num_heads, q_len, d)
                q = module.self_attn.q_proj(h)
                q = q.view(1, q.shape[1], -1, d)
                q = q.mean(dim=1, keepdim=True)
                q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Mean key: same treatment. Shape: (1, num_key_value_heads, q_len, d)
                k = module.self_attn.k_proj(h)
                k = k.view(1, k.shape[1], -1, d)
                k = k.mean(dim=1, keepdim=True)
                k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

                # Apply RoPE: positions 0..q_len-1 are the only thing distinguishing the
                # tiled copies; then expand keys across query-head groups (GQA)
                cos, sin = model.model.rotary_emb(h, position_ids.to(h.device))
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
                k = k.repeat_interleave(num_key_value_groups, dim=1)

                # Compute attention weights for the last token over all q_len key positions;
                # squeeze() yields shape (num_heads, q_len)
                attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
                attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze()

                # Compute score: area under the cumulated attention curve (mean of the
                # cumulative sum along the key axis), then average the query heads within
                # each KV-head group
                s = torch.cumsum(attn_weights, dim=1).mean(1)
                s = s.view(-1, num_key_value_groups).mean(1)

                # Store the scores: running mean over the num_samples texts
                scores[layer_idx] += s.cpu() / num_samples
    return scores.numpy()
2 changes: 2 additions & 0 deletions kvpress/presses/qfilter_press.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from functools import cache
from contextlib import contextmanager
from dataclasses import dataclass

Expand Down Expand Up @@ -32,6 +33,7 @@ def __post_init_from_model__(self, model):
self.q_filters = self.q_filters.to(model.dtype)

@staticmethod
@cache
def load_q_filters(model_name):
try:
return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ rouge = "^1.0.1"
bert-score = "^0.3.13"
accelerate = "^1.0.0"
requests = "^2.32.3"
cachetools = "^5.5.2"

[tool.poetry.dev-dependencies]
pytest = "^7.0.0"
Expand Down