Merged
7 changes: 4 additions & 3 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -4,11 +4,12 @@ Description of your PR. Fixes # (issue) (if applicable)

## Checklist

- Tests are working (make test)
- Code is formatted correctly (make style, on errors try fix with make format)
- Tests are working (`make test`)
- Code is formatted correctly (`make style`, on errors try fix with `make format`)
- Copyright header is included
- [ ] All commits are signed-off using `git commit -s`
- [ ] (new press) `mypress_press.py` is in the `presses` directory
- [ ] (new press) `MyPress` is in `__init__.py`
- [ ] (new press) `README.md` is updated with a 1 liner about the new press in the Available presses section
- [ ] (new press) new press is in the `default_presses` list in `tests/default_presses.py`
- [ ] (new press) New press is in the `default_presses` list in `tests/default_presses.py`
- [ ] (new press) A docstring is provided that follows the same structure as the existing ones
4 changes: 2 additions & 2 deletions README.md
@@ -93,7 +93,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkKVPress` ([source](kvpress/presses/chunkkv_press.py), [paper](https://arxiv.org/abs/2502.00299)): compresses by selecting important chunks, preserving semantic coherence
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compresses the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
- `BlockPress` ([source](kvpress/presses/keydiff_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
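As an aside for readers new to chunk-wise methods: the idea behind `ChunkPress`-style uniform pruning can be sketched in a few lines of NumPy (illustration only, not kvpress code; all names hypothetical):

```python
import numpy as np

def chunked_keep_mask(scores, chunk_size, compression_ratio):
    # Keep the top (1 - compression_ratio) fraction of tokens inside every
    # chunk, so pruning is spread uniformly along the sequence.
    keep = np.zeros(len(scores), dtype=bool)
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]
        k = int(np.ceil(len(chunk) * (1 - compression_ratio)))
        keep[start + np.argsort(chunk)[-k:]] = True
    return keep

scores = np.arange(16, dtype=float)  # toy importance scores
mask = chunked_keep_mask(scores, chunk_size=8, compression_ratio=0.5)
print(mask.sum())  # 8: four tokens kept in each of the two chunks
```

Giving every chunk the same budget is what makes the compression uniform, instead of letting one high-scoring region consume the whole budget.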

@@ -134,7 +134,7 @@ By default, the `DynamicCache` is used (no quantization).
### Which models are supported?
</summary>

Some presses depend on the model architecture (_e.g._ `ExpectedAttentionPress` or `SnapKVPress`) hence they might not work with all models. We tested support for `LlamaForCausalLM`, `MistralForCausalLM`, `Phi3ForCausalLM` and `Qwen2ForCausalLM` but many other models might be supported out of the box because their implementation is often similar in transformers.
Some presses depend on the model architecture (_e.g._ `ExpectedAttentionPress` or `SnapKVPress`) hence they might not work with all models. We tested support for `LlamaForCausalLM`, `MistralForCausalLM`, `Phi3ForCausalLM`, `Qwen2ForCausalLM`, `Qwen3ForCausalLM`, and `Gemma3ForCausalLM` but many other models might be supported out of the box because their implementation is often similar in transformers.
</details>

<details><summary>
4 changes: 2 additions & 2 deletions evaluation/evaluate.py
@@ -21,13 +21,15 @@

from kvpress import (
AdaKVPress,
BlockPress,
ChunkKVPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
DuoAttentionPress,
ExpectedAttentionPress,
FinchPress,
KeyDiffPress,
KnormPress,
ObservedAttentionPress,
PyramidKVPress,
@@ -37,8 +39,6 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
BlockPress,
KeyDiffPress,
)

logger = logging.getLogger(__name__)
5 changes: 3 additions & 2 deletions evaluation/longbench/calculate_metrics.py
@@ -4,6 +4,7 @@
import re
import string
from collections import Counter

import numpy as np
from rouge import Rouge

@@ -37,7 +38,7 @@ def calculate_metrics_e(df):

def scorer_e(dataset, predictions, answers, lengths, all_classes):
scores = {"0-4k": [], "4-8k": [], "8k+": []} # type:ignore[var-annotated]
for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
for prediction, ground_truths, length in zip(predictions, answers, lengths):
score = 0.0
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
prediction = prediction.lstrip("\n").split("\n")[0]
@@ -56,7 +57,7 @@ def scorer_e(dataset, predictions, answers, lengths, all_classes):

def scorer(dataset, predictions, answers, all_classes):
total_score = 0.0
for (prediction, ground_truths) in zip(predictions, answers):
for prediction, ground_truths in zip(predictions, answers):
score = 0.0
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
prediction = prediction.lstrip("\n").split("\n")[0]
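For reference, the loop above boils down to "best score over the reference answers, averaged over the dataset". A self-contained sketch with a toy `exact_match` metric (hypothetical names, standing in for the LongBench per-dataset scorers):

```python
import numpy as np

def exact_match(prediction, truth):
    # Toy metric standing in for the per-dataset scoring functions
    return float(prediction.strip().lower() == truth.strip().lower())

def score(predictions, answers, metric=exact_match):
    # Best score over all reference answers, averaged over the dataset,
    # mirroring the `scorer` loop above.
    per_example = [max(metric(p, t) for t in truths)
                   for p, truths in zip(predictions, answers)]
    return round(100 * float(np.mean(per_example)), 2)

print(score(["Paris", "blue"], [["paris", "Paris, France"], ["red"]]))  # 50.0
```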
15 changes: 7 additions & 8 deletions kvpress/__init__.py
@@ -5,31 +5,30 @@
from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.block_press import BlockPress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.keydiff_press import KeyDiffPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS
from kvpress.presses.block_press import BlockPress
from kvpress.presses.keydiff_press import KeyDiffPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
49 changes: 47 additions & 2 deletions kvpress/attention_patch.py
@@ -10,6 +10,26 @@ def search_hyperplane(X, max_iter: int = 1000):
Given a tensor X of shape (bsz, seq_len, head_dim), search for a hyperplane Y (bsz, head_dim)
such that for every i, <X[:, i], Y> <= 0. Returns -1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0.
Raises a ValueError if no such hyperplane is found.

Parameters
----------
X : torch.Tensor
Query tensor with shape (batch_size, seq_len, head_dim) representing
the query vectors for which we want to find a nullifying hyperplane.
max_iter : int, default=1000
Maximum number of iterations to search for the hyperplane. If no valid
hyperplane is found within this limit, a ValueError is raised.

Returns
-------
torch.Tensor
Hyperplane tensor with shape (batch_size, head_dim) scaled by -1e5 / ||Y||²
to ensure that exp(<X, Y>) ≈ 0 for all queries in X.

Raises
------
ValueError
If no valid hyperplane is found within max_iter iterations.
"""
Y = X.mean(1) # this initialization is enough for most cases
for _ in range(max_iter):
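The search itself can be sketched as a batch-perceptron loop. The following NumPy illustration is a simplified, single-batch approximation (hypothetical names; the actual kvpress implementation differs in details):

```python
import numpy as np

def search_hyperplane_sketch(X, max_iter=1000):
    # X: (seq_len, head_dim). Find y with <x_i, y> > 0 for all rows, then
    # return k = -1e5 * y / ||y||^2 so that every <x_i, k> is a large
    # negative number and exp(<x_i, k>) vanishes in the softmax.
    y = X.mean(axis=0)  # same initialization trick as in the real code
    for _ in range(max_iter):
        violations = X @ y <= 0
        if not violations.any():
            return -1e5 * y / (y @ y)
        y += X[violations].mean(axis=0)  # batch-perceptron correction
    raise ValueError("no separating hyperplane found")

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 8))
X[:, 0] = np.abs(X[:, 0]) + 1.0  # guarantees a separating hyperplane exists
k = search_hyperplane_sketch(X)
assert (X @ k < 0).all()  # the fake key gets (near-)zero attention weight
```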
@@ -25,6 +45,17 @@ def attention_patch(func):
Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices.
The keys are updated with a fake key k such that exp(<q, k>) = 0, faking head-wise compression.
This solution is not optimal as it does not reduce peak memory and slightly increases runtime.

Parameters
----------
func : callable
The original attention function to be patched. Should accept parameters
(module, query, key, value, attention_mask, dropout, **kwargs).

Returns
-------
callable
The wrapped attention function that supports head-wise key masking.
"""

def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
@@ -54,8 +85,22 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):

def patch_attention_functions():
"""
Add the attention_patch decorator to functions in ALL_ATTENTION_FUNCTIONS
"""
Apply attention patching to all transformer attention functions.

This function automatically patches all attention functions registered in
transformers' ALL_ATTENTION_FUNCTIONS to support head-wise key masking.
It enables KVPress compression methods that require head-specific masking
(like AdaKV) to work correctly during text generation.

The patching is applied globally and affects all transformer models loaded
after this function is called. It's automatically called when importing
kvpress to ensure compatibility with head-wise compression methods.

Notes
-----
This function modifies the global attention functions in the transformers
library. The modifications do not affect models that don't use head-wise compression (i.e. don't have
module.masked_key_indices).
"""
for name, func in ALL_ATTENTION_FUNCTIONS.items():
ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func)
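The effect of the patch is easy to verify numerically: overwriting a key with a vector whose dot product with the query is hugely negative removes that position from the softmax. A standalone NumPy illustration with a single query (not kvpress code; the real patch handles all queries at once via `search_hyperplane`):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
head_dim = 4
q = rng.normal(size=head_dim)
keys = rng.normal(size=(5, head_dim))

# "Compress" key 2 for this head: overwrite it with a fake key whose
# dot product with the query is -1e5, so it gets zero attention weight.
keys[2] = -1e5 * q / (q @ q)

weights = softmax(keys @ q)
print(weights[2])  # 0.0: the masked position no longer receives attention
```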
58 changes: 43 additions & 15 deletions kvpress/pipeline.py
@@ -22,9 +22,17 @@

class KVPressTextGenerationPipeline(Pipeline):
"""
Pipeline for key-value compression in causal language models.
This pipeline allows you to compress a long prompt using a key-value press
and then generate answers using greedy decoding.
Pipeline for key-value cache compression in causal language models.

Enables efficient processing of long contexts by applying KV cache compression
during pre-filling, then generating answers using greedy decoding.

Example:
```python
pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
press = SnapKVPress(compression_ratio=0.5)
result = pipeline(context="Long text...", question="A question about the long context.", press=press)
```
"""

def _sanitize_parameters(
Expand All @@ -51,7 +59,11 @@ def _sanitize_parameters(
answer_prefix : str, optional
The prefix to be added to the generated answer.
press : BasePress, optional
The key-value press to use for compression.
The key-value cache compression method to apply during pre-filling.

Accepts any KVPress compression method (SnapKVPress, KnormPress,
ExpectedAttentionPress, BlockPress, AdaKVPress, ComposedPress, etc.).
If None, no compression is applied.
max_new_tokens : int, optional
The maximum number of new tokens to generate for each answer.
max_context_length : int, optional
@@ -92,13 +104,27 @@ def preprocess(
max_context_length: int,
):
"""
Apply the chat template to the triplet (context, questions, answer_prefix) and tokenize it.
Apply chat template and tokenize the context and questions.

Prepares input text for KV cache compression and generation by applying
appropriate chat templates and tokenizing. Handles models with and without
chat templates.

Parameters
----------
context : str
Long context text to be compressed using the press method.
questions : list[str]
Questions to be asked about the context.
answer_prefix : str
Optional prefix for generated answers.
max_context_length : int
Maximum tokens allowed in context (truncated if exceeded).

Returns
-------
dict[str, GenericTensor]
A dictionary containing the tokenized context (key: "context_ids") and questions (key: "questions_ids").

Dictionary with "context_ids" and "questions_ids" tensors.
"""

# Apply chat template if available
@@ -140,25 +166,27 @@ def _forward(
cache: Optional[Cache] = None,
):
"""
Forward pass of the kv-press pipeline.
Execute KV cache compression and text generation pipeline.

Performs context compression using the press method during pre-filling,
then generates answers using greedy decoding.

Parameters
----------
input_tensors : dict[str, GenericTensor]
A dictionary containing the tokenized context and questions.
max_new_tokens : int, optional
The maximum number of new tokens to generate for each answer. Defaults to 50.
Tokenized inputs with "context_ids" and "questions_ids".
max_new_tokens : int, default=50
Maximum tokens to generate for each answer.
press : BasePress, optional
The key-value press to use for compression. Defaults to None.
Compression method for context pre-filling. If None, no compression.
cache : Cache, optional
The cache to use for the forward pass. Defaults to None (DynamicCache).
Cache object for forward pass. If None, creates new DynamicCache.

Returns
-------
list[str]
A list of generated answers.
Generated answers for each input question.
"""

context_ids = input_tensors["context_ids"].to(self.model.device)
context_length = context_ids.shape[1]

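After the compressed pre-filling, answer generation is plain greedy decoding. The loop can be sketched independently of any model with a toy `next_logits` callable (hypothetical names; the pipeline delegates this to the model's generation loop):

```python
import numpy as np

def greedy_decode(next_logits, prompt, max_new_tokens=50, eos_id=None):
    # next_logits: callable mapping the current token list to a logits vector
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        token = int(np.argmax(next_logits(tokens)))
        if token == eos_id:
            break
        tokens.append(token)
    return tokens[len(prompt):]

# Toy "model": always predicts (last token + 1) mod 10, with 9 acting as EOS
toy = lambda tokens: np.eye(10)[(tokens[-1] + 1) % 10]
print(greedy_decode(toy, [3], max_new_tokens=8, eos_id=9))  # [4, 5, 6, 7, 8]
```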
21 changes: 17 additions & 4 deletions kvpress/presses/adakv_press.py
@@ -13,10 +13,23 @@
@dataclass
class AdaKVPress(BasePress):
"""
AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer
based on the scores, achieving head-specific compression.
A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter)
This press has been reviewed by Yuan Feng, first author of AdaKV.
AdaKV: Adaptive head-wise KV cache compression.

Performs head-specific compression by selecting top-k tokens across all heads
based on importance scores. Applies safeguards to ensure each head retains
a minimum fraction of tokens.

Based on AdaKV (https://arxiv.org/abs/2407.11550).

Parameters
----------
press : ScorerPress
The scorer press providing the importance scores. AdaKVPress and ObservedAttentionPress are currently not supported.
alpha_safeguard : float, default=0.20
Minimum fraction of KV pairs that each head must retain.
Ensures no attention head is compressed too aggressively. Even if tokens
receive low global importance scores, each head retains at least this
fraction of its original tokens.
"""

press: ScorerPress
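The allocation logic can be sketched as "per-head safeguard, then global top-k". A NumPy illustration (hypothetical names; the actual implementation operates on the attention module's cache tensors):

```python
import numpy as np

def adakv_keep_mask(scores, compression_ratio, alpha_safeguard=0.20):
    # scores: (num_heads, seq_len). Returns a boolean keep-mask per head.
    num_heads, seq_len = scores.shape
    per_head_budget = int(seq_len * (1 - compression_ratio))
    total_budget = num_heads * per_head_budget

    keep = np.zeros((num_heads, seq_len), dtype=bool)

    # Safeguard: every head keeps at least alpha_safeguard * budget tokens
    min_keep = max(1, int(alpha_safeguard * per_head_budget))
    for h in range(num_heads):
        keep[h, np.argsort(scores[h])[-min_keep:]] = True

    # Spend the remaining budget on the globally highest-scoring pairs
    remaining = total_budget - int(keep.sum())
    flat = np.where(keep.ravel(), -np.inf, scores.ravel())
    keep.ravel()[np.argsort(flat)[-remaining:]] = True
    return keep

rng = np.random.default_rng(0)
mask = adakv_keep_mask(rng.normal(size=(4, 16)), compression_ratio=0.5)
print(mask.sum(), mask.sum(axis=1))  # 32 kept overall; per-head counts vary
```

The safeguard guarantees no head is emptied out even when all of its tokens score poorly globally.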