Closed

Changes from all commits (45)
73dcb45
first commit
FaureElia Mar 12, 2025
81b53f9
created setup
FaureElia Mar 12, 2025
8f87626
first basic implementation of FinchPress
FaureElia Mar 13, 2025
1db9786
adjusted FinchPress
FaureElia Mar 14, 2025
98bc55a
introduce new parameter condition_len (still have to modify compute_f…
FaureElia Mar 15, 2025
008e7e9
Adapted compute_finch_attention to condition_len. TODO: Normalization…
miriam-16 Mar 17, 2025
5f2d089
fix question len, adjusted the question ids for generation.
giulio98 Mar 15, 2025
c512283
fix: parameter condition_len in function
miriam-16 Mar 17, 2025
87ba07e
add nomralization
FaureElia Mar 17, 2025
fa0e28d
clean FinchPress
FaureElia Mar 18, 2025
c7dff0e
normalization
giulio98 Mar 18, 2025
8970c97
fix len question ids
giulio98 Mar 18, 2025
84ce456
add self
giulio98 Mar 18, 2025
f0efe61
chunked forward -- not working(gibberish output), check shapes and le…
giulio98 Mar 18, 2025
15db060
Fixed attention mask and binary mask in compute_finch_attention
FaureElia Mar 22, 2025
26c64b0
updated gitignore
FaureElia Mar 22, 2025
a6ebc11
n_kept value at last iteration (TODO: further checks)
miriam-16 Mar 22, 2025
e2f3dd6
fixed pipeline for other presses and finished adjusting number of tok…
FaureElia Mar 23, 2025
54c991b
Fix: n_kept at last iteration considers condition_len
miriam-16 Mar 23, 2025
91e9f05
adjusted n_kept, at final chunk keeps 90% of context+full question
FaureElia Mar 23, 2025
c48ea98
added docstring and general code cleaning
miriam-16 Mar 24, 2025
7ec514a
reorder indices, add normalization parameter, add sink tokens
giulio98 Mar 25, 2025
657a27d
partially fixed finch press
FaureElia Mar 26, 2025
d0bb5d0
fixed finch press, aligned to original code
FaureElia Mar 28, 2025
fdde74d
preparing for pr
giulio98 Apr 4, 2025
c1acb0d
add ChunkKV
Dominic789654 Mar 5, 2025
8533beb
Update copyright date (#60)
SimJeg Mar 13, 2025
e71f4e9
Add QFilterPress (#54)
NathanGodey Mar 17, 2025
4f917cc
Add longbench benchmark
Xnhyacinth Mar 19, 2025
8d95b14
Add DuoAttention on the fly (#63)
SimJeg Mar 19, 2025
cace3a5
Resolve stash conflict in tests/default_presses.py
giulio98 Apr 4, 2025
a833ae3
Fix: remove conflict markers after saving changes
giulio98 Apr 4, 2025
39417ad
test passes
giulio98 Apr 4, 2025
2cdd69c
add readme
giulio98 Apr 4, 2025
0e4d4c8
commit
giulio98 Apr 4, 2025
ea42501
Merge branch 'main' into features/finch_press
giulio98 Apr 4, 2025
c570844
make style
giulio98 Apr 4, 2025
9afd7f8
remove leftovers
giulio98 Apr 4, 2025
e8d84fa
make style check pass
giulio98 Apr 6, 2025
689c9c5
fix: readme
giulio98 Apr 6, 2025
f628ca4
unmodify qfilter_press.py duo_attention_press.py and longbech/calcula…
giulio98 Apr 14, 2025
c5a9ca7
typo duo attention press import
giulio98 Apr 14, 2025
f8034e5
use SnapKV window attention for finch
giulio98 Apr 14, 2025
f830665
update past_key_values using last_output from original forward
giulio98 Apr 14, 2025
3f2ed03
add import snapkv
giulio98 Apr 15, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -67,6 +67,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during the pre-filling phase
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
- `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): average attention weight of the prompt over the context tokens.

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries

@@ -82,7 +83,6 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

## Evaluation
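As a rough illustration of the `FinchPress` scoring idea in the README bullet above (average attention weight of the prompt), here is a pure-Python toy sketch. The `finch_scores` helper and the numbers are hypothetical, not kvpress code:

```python
# Toy sketch of Finch-style scoring: each context token is scored by the
# attention it receives from the question (prompt) tokens, averaged over
# the question rows. Hypothetical helper and values, not the kvpress API.

def finch_scores(attn_rows):
    """attn_rows: one attention row per question token, over context tokens."""
    n_ctx = len(attn_rows[0])
    return [sum(row[j] for row in attn_rows) / len(attn_rows) for j in range(n_ctx)]

# Two question tokens attending over four context tokens
attn = [
    [0.1, 0.6, 0.2, 0.1],
    [0.3, 0.4, 0.2, 0.1],
]
scores = finch_scores(attn)
print(scores)  # context token 1 gets the highest score and would be kept first
```

Tokens with the highest averaged scores are the ones `FinchPress` keeps in the KV cache.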
4 changes: 3 additions & 1 deletion evaluation/evaluate.py
@@ -27,14 +27,15 @@
    CriticalKVPress,
    DuoAttentionPress,
    ExpectedAttentionPress,
    FinchPress,
    KnormPress,
    ObservedAttentionPress,
    QFilterPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    ThinKPress,
    TOVAPress,
)

logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@
    "think": ThinKPress(),
    "tova": TOVAPress(),
    "duo_attention": DuoAttentionPress(),
    "finch": FinchPress(),
    "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "qfilter": QFilterPress(),
4 changes: 3 additions & 1 deletion kvpress/__init__.py
@@ -12,18 +12,19 @@
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
@@ -49,6 +50,7 @@
    "KeyRerotationPress",
    "ChunkPress",
    "DuoAttentionPress",
    "FinchPress",
    "ChunkKVPress",
    "QFilterPress",
]
13 changes: 12 additions & 1 deletion kvpress/pipeline.py
@@ -12,6 +12,7 @@
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
@@ -161,6 +162,14 @@ def _forward(
        context_ids = input_tensors["context_ids"].to(self.model.device)
        context_length = context_ids.shape[1]

        if isinstance(press, FinchPress) or isinstance(getattr(press, "press", None), FinchPress):
            # FinchPress cannot be used with multiple questions
            assert len(input_tensors["questions_ids"]) == 1, "FinchPress cannot be used with multiple questions"
            question_ids = input_tensors["questions_ids"][0].to(self.model.device)
            # Append the question (minus its last token) to the context; keep the last token for generation
            context_ids = torch.cat((context_ids, question_ids[:, :-1]), dim=1)
            press.condition_len = question_ids.shape[1] - 1
            input_tensors["questions_ids"][0] = question_ids[:, -1:]

        # Prefilling using the press on the context
        if cache is None:
            cache = DynamicCache()
@@ -182,7 +191,9 @@
            answer = self.generate_answer(
                question_ids=question_ids.to(self.model.device),
                cache=cache,
                context_length=(
                    cache.get_seq_length() if isinstance(press, (KeyRerotationPress, FinchPress)) else context_length
                ),
                max_new_tokens=max_new_tokens,
            )
            answers.append(answer)
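The `_forward` change above appends the question (minus its last token) to the context before prefilling, records the appended length in `press.condition_len`, and leaves only the last question token for generation. A minimal sketch of that bookkeeping, with plain Python lists standing in for token-id tensors and a hypothetical `prepare_finch_inputs` helper (not the kvpress API):

```python
# Sketch of the FinchPress pipeline bookkeeping: the question (minus its last
# token) is prefilled together with the context, condition_len records the
# appended length, and only the last question token is left for generation.
# Toy token ids and a hypothetical helper, no torch.

def prepare_finch_inputs(context_ids, question_ids):
    appended = question_ids[:-1]            # question minus its last token
    new_context = context_ids + appended    # prefilled together with the context
    condition_len = len(appended)           # length of the conditioning window
    remaining_question = question_ids[-1:]  # what generate_answer still receives
    return new_context, condition_len, remaining_question

ctx, cond, rest = prepare_finch_inputs([1, 2, 3, 4], [10, 11, 12])
print(ctx, cond, rest)  # [1, 2, 3, 4, 10, 11] 2 [12]
```

This is why the pipeline also switches `context_length` to `cache.get_seq_length()` for `FinchPress`: the cache now covers context plus conditioning window.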
256 changes: 256 additions & 0 deletions kvpress/presses/finch_press.py
@@ -0,0 +1,256 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import math
from contextlib import contextmanager
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, QuantizedCache
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class FinchPress(BasePress):
    """
    Finch uses the attention information between the prompt and the document chunk to dynamically
    identify the most relevant KV pairs across different layers.
    This information is then stored in the KV cache for the processing of the next input chunk
    (https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280).

    Parameters
    ----------
    compression_ratio : float
        Fraction of the context key-value pairs to prune.
    split_size : int
        The number of chunks to split the context into.
    normalize_scores : bool
        Whether to rescale attention scores by the number of positions each query can attend to.
    condition_len : int
        The number of tokens in the question; computed dynamically from the question length if None.
    """

    compression_ratio: float = 0.0
    split_size: int = 1
    normalize_scores: bool = True
    condition_len: int = None  # computed dynamically from the question length

    def score(self, module, hidden_states, keys, values, attentions, kwargs):

        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            attn_weights = attentions[..., -self.condition_len :, : -self.condition_len]
        else:
            attn_weights = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"]
            )

        if self.normalize_scores:
            # Rescale each query row by the number of keys it can attend to (causal mask)
            non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
        scores = scores.mean(dim=2)

        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
        return scores

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.compression_ratio == 0:
            return keys, values

        # Compute scores from base press
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        context_length = kwargs["context_length"]
        last_iteration = kwargs["split_idx"] == self.split_size - 1
        q_len = hidden_states.shape[1]

        if last_iteration:
            n_kept_context = int(context_length * (1 - self.compression_ratio))
        else:
            past_cache_len = scores.shape[-1] - q_len
            n_kept_context = int((q_len - self.condition_len) * (1 - self.compression_ratio)) + past_cache_len
Review thread on lines +106 to +109:

**Collaborator:** Not correct: for instance, for `last_iteration` it should be

    n_kept_context = int(context_length * (1 - self.compression_ratio) - self.condition_len * self.compression_ratio)

**Contributor (author):** The condition_len shouldn't be compressed, should it? In other presses the question is provided as-is, so if we also compress the question the comparison will not be fair.

**Collaborator:** You are using the question in the input, so it should be compressed too, as is done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but that should translate to a slightly higher compression of the context.

**Collaborator:** Note that the difference is very small (~0.5% in compression ratio), so I don't think it should impact performance much.

**Contributor (author):**

> You are using the question in the input so it should be compressed too, as it's done in other presses.

Query-aware compression like Finch or SnapKV and question-agnostic compression come with trade-offs. Query-aware compression gives the highest performance, but the compression has to be rerun for each new query; question-agnostic presses can run independently of the question, giving more throughput. This is a topic we explored in our recent paper (https://arxiv.org/pdf/2503.04973), where we propose a middle ground and compress using just the task description and a few-shot examples. That being said, the final token budget should be equal across approaches for a fair comparison, IMO. What do you think?

**Collaborator:**

> but we have to rerun the compression for each new query

Approaches like SnapKV or Finch are close to sparse attention (i.e. they retrieve the right KV pairs from the full KV cache) and somewhat different from compression, because, as you mention, you have to re-run them for each query.

> That being said the final budget of tokens should be equal for all the approaches for a fair comparison IMO

I agree, that's why I proposed this update. Could you review #69 and comment? You can also open a new PR with the same code if you want to appear in the contributors of the repo.

**Contributor (author):** Would it be possible to add Co-authored-by to your commits, adding myself, @miriam-16 and @FaureElia?

**Collaborator:** Done, could you please close this branch?


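To make the budget discussion in the thread above concrete, here is a small worked example comparing the kept-token counts under the PR's formula and the reviewer's proposal. The numbers are toys, and "total kept" counts the uncompressed question window in both cases:

```python
# Worked example of the two token budgets debated above (toy numbers).
# PR formula: the question (condition_len tokens) is kept on top of the
# context budget. Reviewer's proposal: the question counts against the
# same global budget, so the context is compressed slightly more.

context_length = 1000   # context tokens (question excluded)
condition_len = 20      # question tokens, never pruned
compression_ratio = 0.5

# PR: n_kept_context = int(context_length * (1 - compression_ratio))
kept_pr = int(context_length * (1 - compression_ratio)) + condition_len

# Reviewer: n_kept_context = int(context_length * (1 - cr) - condition_len * cr)
kept_reviewer = int(context_length * (1 - compression_ratio)
                    - condition_len * compression_ratio) + condition_len

print(kept_pr, kept_reviewer)  # 520 510 -- a small difference, as noted above
```

The gap shrinks as the question gets shorter relative to the context, which is why the thread calls the effect on the compression ratio small.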
        indices_context = scores[:, :, : -self.condition_len].topk(n_kept_context, dim=-1).indices
        indices_condition = torch.arange(
            scores.shape[-1] - self.condition_len, scores.shape[-1], device=scores.device
        )[None, None, :].expand(scores.shape[0], scores.shape[1], -1)
        indices_context, _ = torch.sort(indices_context, dim=-1)

        # Concatenate the context and condition indices
        indices = torch.cat([indices_context, indices_condition], dim=-1)

        # Re-rotate the positions of the kept keys
        new_cos, new_sin = self._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        keys = keys.gather(2, indices).contiguous()
        keys = (keys * new_cos) + (rotate_half(keys) * new_sin)
        values = values.gather(2, indices).contiguous()
        return keys, values

Review comment on `forward_hook` (Collaborator): is there any update here compared to `BasePress`? If not, remove.

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Default forward hook called after the forward pass of an attention layer.
        The hook calls the compress method to compress the KV cache while ensuring:
        - compression is only applied during the pre-filling phase
        - KV cache quantization is handled correctly

        Parameters
        ----------
        module :
            Transformer attention layer.
        input :
            Input to the hook. This is the input to the forward pass of the layer.
        kwargs :
            Keyword arguments, as given to the forward pass of the layer.
        output :
            Output of the hook. This is the original output of the forward pass of the layer.

        Returns
        -------
        Modified output of the forward pass of the layer.
        """
        hidden_states = kwargs["hidden_states"]
        cache = kwargs["past_key_value"]
        q_len = hidden_states.shape[1]

        # Don't compress after pre-filling
        if kwargs["cache_position"][-1] > q_len + cache.get_seq_length():
            return output

        if isinstance(cache, QuantizedCache):
            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
        else:
            keys = cache.key_cache[module.layer_idx]
            values = cache.value_cache[module.layer_idx]

        keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)

        if isinstance(cache, QuantizedCache):
            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
            cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
            cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
            cache._seen_tokens = keys.shape[2]
        else:
            cache.key_cache[module.layer_idx] = keys
            cache.value_cache[module.layer_idx] = values
        return output

    @contextmanager
    def __call__(self, model: PreTrainedModel):
        hooks = []
        # Capture the original forward before the try block so the finally clause can always restore it
        original_forward = model.forward
        try:
            for layer in model.model.layers:
                layer.self_attn.rotary_emb = model.model.rotary_emb
                hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

            def chunked_forward(*args, **kwargs):
                args = list(args)
                # Move positional input_ids / attention_mask into kwargs without
                # consuming positionals when the keyword is already present
                if "input_ids" not in kwargs:
                    kwargs["input_ids"] = args.pop(0) if args else None
                if "attention_mask" not in kwargs:
                    kwargs["attention_mask"] = args.pop(0) if args else None
                args = tuple(args)

                input_ids = kwargs["input_ids"]
                attention_mask = kwargs.get("attention_mask")

                # Split input_ids into context and question tokens.
                context_ids = input_ids[:, : -self.condition_len]
                question_ids = input_ids[:, -self.condition_len :]

                if attention_mask is not None:
                    context_attention_mask = attention_mask[:, : -self.condition_len]
                    question_attention_mask = attention_mask[:, -self.condition_len :]

                # Calculate the total number of context tokens.
                context_length = context_ids.shape[1]
                kwargs["context_length"] = context_length

                # Determine the chunk size so that we split context_ids into exactly split_size chunks.
                chunk_size = context_length // self.split_size
                last_output = None

                for i in range(self.split_size):
                    kwargs["split_idx"] = i
                    start = i * chunk_size
                    # For the last chunk, include any remaining tokens.
                    end = start + chunk_size if i < self.split_size - 1 else context_length

                    # Get the current chunk from context_ids and combine with the question tokens.
                    context_chunk = context_ids[:, start:end]
                    kwargs["input_ids"] = torch.cat([context_chunk, question_ids], dim=1)

                    if attention_mask is not None:
                        context_attention_mask_chunk = context_attention_mask[:, start:end]
                        kwargs["attention_mask"] = torch.cat(
                            [context_attention_mask_chunk, question_attention_mask], dim=1
                        )

                    last_output = original_forward(use_cache=True, *args, **kwargs)

                    # Only adjust the past key/value caches if it's not the last iteration.
                    if i < self.split_size - 1:
                        for layer_idx, _ in enumerate(model.model.layers):
                            # Remove the question tokens from the cache for the next iteration
                            last_output.past_key_values.key_cache[layer_idx] = (
                                last_output.past_key_values.key_cache[layer_idx][:, :, : -self.condition_len, :].contiguous()
                            )
                            last_output.past_key_values.value_cache[layer_idx] = (
                                last_output.past_key_values.value_cache[layer_idx][:, :, : -self.condition_len, :].contiguous()
                            )
                        last_output.past_key_values._seen_tokens = last_output.past_key_values.get_seq_length()
                        kwargs["past_key_values"] = last_output.past_key_values

                return last_output

            # Override the model's forward with the chunked version
            model.forward = chunked_forward
            yield
        finally:
            # Remove all hooks and restore the original forward method
            for hook in hooks:
                hook.remove()
            model.forward = original_forward
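The `compress` method of `finch_press.py` keeps the top-scoring context positions (re-sorted into their original order) and always preserves the condition window at the end of the cache. A pure-Python sketch of that selection logic, with toy scores and a hypothetical `select_positions` helper standing in for the `topk`/`sort`/`cat` calls above (not kvpress code):

```python
# Sketch of the FinchPress selection step: keep the n_kept_context
# highest-scoring context positions (restored to original order, mirroring
# torch.sort on the topk indices) plus the whole condition window at the end.
# Hypothetical helper with toy scores, no torch.

def select_positions(scores, condition_len, n_kept_context):
    context_scores = scores[:-condition_len]
    top = sorted(range(len(context_scores)),
                 key=lambda i: context_scores[i], reverse=True)[:n_kept_context]
    kept_context = sorted(top)  # preserve original ordering
    window = list(range(len(scores) - condition_len, len(scores)))
    return kept_context + window

# 6 scored context tokens followed by a 2-token condition window
positions = select_positions([0.1, 0.9, 0.3, 0.8, 0.2, 0.4, 1.0, 1.0],
                             condition_len=2, n_kept_context=3)
print(positions)  # [1, 3, 5, 6, 7]
```

In the real press, these kept positions are then gathered from the key/value tensors and the keys are re-rotated with `_rerotate_cos_sin` so their RoPE positions become contiguous again.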
10 changes: 9 additions & 1 deletion tests/default_presses.py
@@ -6,14 +6,15 @@
from kvpress import (
    DuoAttentionPress,
    ExpectedAttentionPress,
    FinchPress,
    KnormPress,
    QFilterPress,
    RandomPress,
    SimLayerKVPress,
    SnapKVPress,
    StreamingLLMPress,
    ThinKPress,
    TOVAPress,
)


@@ -52,4 +53,11 @@ def load_attention_pattern(model):
            {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1},
        ],
    },
    {
        "cls": FinchPress,
        "kwargs": [
            {"compression_ratio": 0.2, "condition_len": 2},
            {"compression_ratio": 0.8, "condition_len": 2},
        ],
    },
]