Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `LeverageScorePress` ([source](kvpress/presses/leverage_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on approximate statistical leverage (i.e., we preserve outliers in the key space).
- `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio.
- `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prunes keys and values based on the CUR decomposition using approximate leverage scores.
- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `ThresholdPress`.
- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `DMSPress`.

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/abs/2407.21018)): compresses the dimensions of the keys based on the channel attention score on the last queries.
Expand All @@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see decoding section in this README.
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows compression during both prefilling and decoding.
- `ThresholdPress` ([source](kvpress/presses/threshold_press.py)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True).
- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Supports both prefilling and decoding (if `decoding=True`).

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

Expand Down
8 changes: 4 additions & 4 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ObservedAttentionPress,
ScorerPress,
ThinKPress,
ThresholdPress,
DMSPress,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -256,10 +256,10 @@ def _setup_press(self):
if isinstance(press, DuoAttentionPress):
press.head_compression_ratio = compression_ratio
logger.info(f"Set DuoAttentionPress head_compression_ratio to {compression_ratio}")
elif isinstance(press, ThresholdPress):
assert self.config.threshold is not None, "threshold must be set for ThresholdPress"
elif isinstance(press, DMSPress):
assert self.config.threshold is not None, "threshold must be set for DMSPress"
press.threshold = self.config.threshold
logger.info(f"Set ThresholdPress threshold to {press.threshold}")
logger.info(f"Set DMSPress threshold to {press.threshold}")
elif isinstance(press, ComposedPress):
for ps in press.presses:
if isinstance(ps, ThinKPress):
Expand Down
2 changes: 1 addition & 1 deletion evaluation/evaluate_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ data_dir: "4096" # Subdirectory of the dataset
press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
threshold: null # For ThresholdPress
threshold: null # For DMSPress

fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing
max_new_tokens: null # Maximum new tokens to generate (null = use dataset default)
Expand Down
6 changes: 3 additions & 3 deletions evaluation/evaluate_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThresholdPress,
DMSPress,
ThinKPress,
TOVAPress,
CURPress,
Expand Down Expand Up @@ -86,8 +86,8 @@
"keydiff": KeyDiffPress(),
"kvzip": KVzipPress(),
"kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
"kvzap_linear": ThresholdPress(press=KVzapPress(model_type="linear")),
"kvzap_mlp": ThresholdPress(press=KVzapPress(model_type="mlp")),
"kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")),
"kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")),
"kvzap_mlp_head": KVzapPress(model_type="mlp"),
"kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
"lagkv": LagKVPress(),
Expand Down
4 changes: 2 additions & 2 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.threshold_press import ThresholdPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.tova_press import TOVAPress

# Patch the attention functions to support head-wise compression
Expand Down Expand Up @@ -80,5 +80,5 @@
"LeverageScorePress",
"NonCausalAttnPress",
"KVzapPress",
"ThresholdPress",
"DMSPress",
]
4 changes: 2 additions & 2 deletions kvpress/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
from kvpress.presses.threshold_press import ThresholdPress
from kvpress.presses.dms_press import DMSPress

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -224,7 +224,7 @@ def _forward(

# We only perform decoding compression if the press is a decoding or prefill decoding press
perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
if isinstance(press, ThresholdPress):
if isinstance(press, DMSPress):
perform_decoding_compression = press.decoding
with press(self.model) if perform_decoding_compression else contextlib.nullcontext():
# Greedy decoding for each question
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@


@dataclass
class ThresholdPress(BasePress):
class DMSPress(BasePress):
"""
Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
Wraps a ScorerPress and evicts keys/values with scores below a given threshold.

Unlike most presses that use a fixed compression_ratio, ThresholdPress uses a score threshold
Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
to determine which KV pairs to evict. This allows for adaptive compression where the actual
compression ratio depends on the input content.

Expand Down
2 changes: 1 addition & 1 deletion kvpress/presses/kvzap_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class KVzapPress(ScorerPress):
KVzap (https://arxiv.org/abs/2601.07891) is a fast approximation of KVzip that works
in both prefilling and decoding. It applies a lightweight surrogate model to the hidden
states to predict importance scores for every KV pair.
KVzapPress is designed to be used in conjunction with the ThresholdPress
KVzapPress is designed to be used in conjunction with the DMSPress
model_type can be "linear" or "mlp".
"""

Expand Down
10 changes: 5 additions & 5 deletions kvzap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
[![KVzap collection](https://img.shields.io/badge/🤗%20Hugging%20Face-Collection-orange)](https://huggingface.co/collections/nvidia/kvzap)
[![arXiv](https://img.shields.io/badge/arXiv-2601.07891-b31b1b.svg)](https://arxiv.org/abs/2601.07891)

[KVzap](https://arxiv.org/abs/2601.07891) is a fast approximation of [KVzip](https://arxiv.org/abs/2505.23416) that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores, and removes the KV pairs with a score below a given threshold.
[KVzap](https://arxiv.org/abs/2601.07891) is a fast approximation of [KVzip](https://arxiv.org/abs/2505.23416) that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores, and removes the KV pairs with a score below a given threshold, following the Dynamic Memory Sparsification ([DMS](https://arxiv.org/abs/2506.05345)) inference strategy.

## Usage

KVzap is designed to be used by combining the `KVzapPress` and the `ThresholdPress` from kvpress:
KVzap is designed to be used by combining the `KVzapPress` and the `DMSPress` from kvpress:

```python
import requests
from transformers import pipeline
from kvpress import KVzapPress, ThresholdPress
from kvpress import KVzapPress, DMSPress

model = "Qwen/Qwen3-8B"
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", dtype="auto")
press = ThresholdPress(KVzapPress(model_type="mlp"), threshold=-4)
press = DMSPress(KVzapPress(model_type="mlp"), threshold=-4)

# Prefilling compression only, thinking disabled
press.decoding = False
Expand All @@ -32,7 +32,7 @@ answer = pipe(prompt, press=press, enable_thinking=True, max_new_tokens=2000)["a
print(f"Compression ratio: {press.compression_ratio:.2%}\nAnswer: {answer}")
```

The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `ThresholdPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio.
The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `DMSPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio.

Supported base models are provided in the [KVzap collection](https://huggingface.co/collections/nvidia/kvzap) but can easily be extended to any other model following the instructions in the [training section](#training).

Expand Down
8 changes: 4 additions & 4 deletions kvzap/evaluate_aime.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from kvpress import KVzapPress, ThresholdPress
from kvpress import KVzapPress, DMSPress


def calculate_metrics(df):
Expand Down Expand Up @@ -56,11 +56,11 @@ def evaluate(
"""

# Create press
press: ThresholdPress | type[nullcontext[None]]
press: DMSPress | type[nullcontext[None]]
if kvzap_model_type == "no_press":
press = nullcontext
else:
press = ThresholdPress(
press = DMSPress(
KVzapPress(model_type=kvzap_model_type),
threshold=threshold,
decoding=True,
Expand All @@ -86,7 +86,7 @@ def evaluate(
)
answer = tokenizer.decode(output_tokens[0, tokens.shape[1] :])
df.loc[idx, "predicted_answer"] = answer
if isinstance(press, ThresholdPress):
if isinstance(press, DMSPress):
df.loc[idx, "compression_ratio"] = press.compression_ratio
else:
df.loc[idx, "compression_ratio"] = 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "kvpress"
version = "0.4.1"
version = "0.4.2"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
Expand Down
8 changes: 4 additions & 4 deletions tests/presses/test_head_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from transformers import DynamicCache

from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, ThresholdPress
from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, DMSPress
from tests.fixtures import unit_test_model, kv_press_unit_test_pipeline # noqa: F401


Expand Down Expand Up @@ -59,9 +59,9 @@ def test_head_compression(unit_test_model, press, compression_ratio, layerwise):
assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences


def test_threshold_press_compression_ratio(kv_press_unit_test_pipeline): # noqa: F811
"""Test that ThresholdPress.compression_ratio matches the actual masked percentage."""
press = ThresholdPress(
def test_dms_press_compression_ratio(kv_press_unit_test_pipeline): # noqa: F811
"""Test that DMSPress.compression_ratio matches the actual masked percentage."""
press = DMSPress(
press=RandomPress(),
threshold=0.5,
sliding_window_size=0,
Expand Down
8 changes: 4 additions & 4 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ObservedAttentionPress,
ScorerPress,
SnapKVPress,
ThresholdPress,
DMSPress,
ThinKPress,
)
from tests.default_presses import default_presses
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_chunkkv_press(unit_test_model): # noqa: F811
ChunkPress,
CriticalKVPress,
CriticalAdaKVPress,
ThresholdPress,
DMSPress,
],
)
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
Expand All @@ -89,8 +89,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
press = wrapper_press(press=press)
elif issubclass(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=24)
elif issubclass(wrapper_press, ThresholdPress):
press = ThresholdPress(press=press, threshold=-0.5, sliding_window_size=32)
elif issubclass(wrapper_press, DMSPress):
press = DMSPress(press=press, threshold=-0.5, sliding_window_size=32)

# TODO: Handle post_init_from_model differently
if hasattr(press, "post_init_from_model"):
Expand Down
Loading