Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
- `SimLayerKVPress` ([source](kvpress/presses/simlayerkv_press.py), [paper](https://arxiv.org/abs/2410.13846)): identify "lazy" layers, and apply the StreamingLLM approach to them
- `DuoAttentionPress` ([source](kvpress/presses/duo_attention_press.py), [paper](https://arxiv.org/abs/2410.10819)): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
- `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): similar to SnapKV with a dynamic window size and key value re-rotation

Finally we provide wrapper presses that can be combined with other presses:
- `AdaKVPress` ([source](kvpress/presses/adakv_press.py), [paper](https://arxiv.org/abs/2407.11550)): prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions
Expand Down
16 changes: 11 additions & 5 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TOVAPress,
QFilterPress,
PyramidKVPress,
FinchPress,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,6 +83,7 @@
"qfilter": QFilterPress(),
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
"pyramidkv": PyramidKVPress(),
"finch": FinchPress(),
}


Expand Down Expand Up @@ -154,11 +156,6 @@ def evaluate(
save_filename.stem + f"__max_context{max_context_length}" + save_filename.suffix
)

if compress_questions:
df["context"] = df["context"] + df["question"]
df["question"] = ""
save_filename = save_filename.with_name(save_filename.stem + "__compressed_questions" + save_filename.suffix)

# Load press
assert press_name in PRESS_DICT
press = PRESS_DICT[press_name]
Expand Down Expand Up @@ -198,6 +195,15 @@ def evaluate(
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
else:
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

if isinstance(press, FinchPress):
df["context"] = df["context"] + pipe.tokenizer.bos_token

if compress_questions:
df["context"] = df["context"] + df["question"]
df["question"] = ""
save_filename = save_filename.with_name(save_filename.stem + "__compressed_questions" + save_filename.suffix)

# Run pipeline on each context
df["predicted_answer"] = None
df_context = df.groupby("context")
Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.finch_press import FinchPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
Expand Down Expand Up @@ -53,4 +54,5 @@
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
]
6 changes: 5 additions & 1 deletion kvpress/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
Expand Down Expand Up @@ -179,10 +180,13 @@ def _forward(
# Greedy decoding for each question
answers = []
for question_ids in input_tensors["questions_ids"]:
if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
context_length = cache.get_seq_length()

answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
cache=cache,
context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length),
context_length=context_length,
max_new_tokens=max_new_tokens,
)
answers.append(answer)
Expand Down
134 changes: 134 additions & 0 deletions kvpress/presses/finch_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass, field
from contextlib import contextmanager

import torch
from torch.nn import functional as F

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress
from transformers.models.llama.modeling_llama import rotate_half


@dataclass
class FinchPress(BasePress):
    """
    Implementation of Finch (https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)
    without chunked prefilling.

    Finch starts with SnapKV-style compression, but the window size is not fixed. Instead, the user must provide
    a second <bos_token> between the context and the window (input = context + tokenizer.bos_token + question)

    The following options are also available:
    - normalizing scores using the number of non-zero attention weights in the window
    - compressing by chunks
    - rerotating keys after compression (similar to KeyRerotationPress)
    """

    # Fraction of the context KV pairs to prune (0.0 disables compression)
    compression_ratio: float = 0.0
    # If set, top-k selection is performed chunk by chunk of this length instead of globally
    chunk_length: int = None
    # Re-weight each window row by the number of keys it can attend to before averaging
    normalize_scores: bool = True
    # Re-apply RoPE to the kept keys so their positions become contiguous after pruning
    rerotate_keys: bool = True
    # Inferred at prefill time by embed_token_forward_hook (from the second <bos_token>), hence init=False
    window_size: int = field(default=None, init=False)

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Similar to SnapKVPress except it adds a normalization step before averaging on the context window.

        Returns per-head scores of shape (bsz, num_key_value_heads, q_len), where the last
        window_size positions are padded with the max score so the window is never pruned.
        """

        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            # Attention of the last window_size queries over the context keys
            attn_weights = attentions[..., -self.window_size :, : -self.window_size]
        else:
            # Attention weights are not materialized (e.g. sdpa / flash attention): recompute them for the window
            attn_weights = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, self.window_size, kwargs["position_embeddings"]
            )

        if self.normalize_scores:
            # Under the causal mask, window row i attends to q_len - window_size + i keys.
            # Multiply each row by that count to undo the softmax's 1/row-length scaling,
            # so rows with more visible keys are not systematically down-weighted.
            non_zero_counts = torch.arange(q_len - self.window_size, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average over window rows, then average attention heads per key-value group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size)
        scores = scores.mean(dim=2)

        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.window_size), value=scores.max().item())
        return scores

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Scores are computed by chunks, keys and values are then compressed and re-rotated.

        Returns the pruned (keys, values) pair; both keep their original relative order.
        """

        if self.compression_ratio == 0:
            return keys, values
        # window_size is set by embed_token_forward_hook during prefilling
        assert self.window_size is not None, "window_size must be provided"

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Compute indices to keep (optionally by chunks)
        q_len = hidden_states.shape[1]
        if self.chunk_length is None:
            n_kept = int(q_len * (1 - self.compression_ratio))
            indices = scores.topk(n_kept, dim=-1).indices
        else:
            # Each chunk must be able to retain at least the (never pruned) observation window
            assert self.chunk_length > self.window_size / (1 - self.compression_ratio)
            indices = []
            for i in range(0, q_len, self.chunk_length):
                chunk_scores = scores[:, :, i : i + self.chunk_length]
                n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio)))
                # Offset by i to convert chunk-local indices back to sequence positions
                chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
                indices.append(chunk_indices)
            indices = torch.cat(indices, dim=-1)

        # Sort so the gathered keys/values preserve their original order, then expand for gather over head_dim
        indices = torch.sort(indices, dim=2).values
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        # Rerotate keys
        if self.rerotate_keys:
            cos, sin = kwargs["position_embeddings"]
            # Inverse RoPE (rotation by -position): cos is even, sin is odd, hence the -sin term
            keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
            keys = keys.gather(2, indices).contiguous()
            # Re-apply RoPE with contiguous positions 0 .. n_kept - 1
            cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
            keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
        else:
            keys = keys.gather(2, indices).contiguous()

        values = values.gather(2, indices).contiguous()

        return keys, values

    def embed_token_forward_hook(self, module, input, output):
        """
        Forward hook to detect a second <bos_token> delimiting the context and the window.

        Sets self.window_size and removes the delimiter's embedding from the output so the
        model never sees the extra token.
        """
        # A leading <bos_token> signals prefilling; decoding steps fall through unchanged
        if input[0][0, 0] == self.bos_token_id:  # prefilling
            # len(input[0]) is the batch dimension of the input_ids tensor
            assert len(input[0]) == 1, "Only batch size 1 is supported."
            try:
                # Position of the second <bos_token>, i.e. the length of the context
                context_length = int(torch.nonzero(input[0][0] == self.bos_token_id)[1].item())
                self.window_size = len(input[0][0]) - 1 - context_length
                assert self.window_size > 0, "No window detected (window size must be > 0)."
                # Remove the second <bos_token> from the output
                output = torch.cat([output[:, :context_length], output[:, context_length + 1 :]], dim=1)
            except IndexError:
                # torch.nonzero(...)[1] failed: only one <bos_token> was found
                raise IndexError("A second <bos_token> must delimit the context and the question.")
        return output

    @contextmanager
    def __call__(self, model):
        # Cache the bos token id so the embedding hook can locate the context/window delimiter
        self.bos_token_id = model.generation_config.bos_token_id
        with super().__call__(model):
            try:
                # Hook the embedding layer to infer window_size and strip the delimiter token
                hook = model.model.embed_tokens.register_forward_hook(self.embed_token_forward_hook)
                yield
            finally:
                hook.remove()
1 change: 1 addition & 0 deletions kvpress/presses/key_rerotation_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def compress(
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.press.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = torch.sort(indices, dim=2).values
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

cos, sin = kwargs["position_embeddings"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.4"
version = "0.2.5"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down
18 changes: 18 additions & 0 deletions tests/presses/test_finch_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from kvpress import FinchPress
from tests.fixtures import unit_test_model # noqa: F401


def test_finch_press(unit_test_model):  # noqa: F811
    """Smoke-test a forward pass under several FinchPress configurations.

    The input places a <bos_token> at position 0 (prefill marker) and at position 8,
    which delimits the context from the observation window.
    """
    configurations = [
        dict(compression_ratio=0.5),
        dict(compression_ratio=0.5, rerotate_keys=False),
        dict(compression_ratio=0.5, normalize_scores=False),
        dict(compression_ratio=0.2, chunk_length=5),
    ]
    bos = unit_test_model.generation_config.bos_token_id
    for config in configurations:
        press = FinchPress(**config)
        with press(unit_test_model):
            input_ids = torch.arange(10, 20)
            input_ids[0] = bos  # prefill marker
            input_ids[8] = bos  # context / window delimiter
            unit_test_model(input_ids.unsqueeze(0))
1 change: 1 addition & 0 deletions tests/presses/test_key_rerotation_press_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def score(
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = torch.sort(indices, dim=2).values
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
self.indices = indices

Expand Down