From 595c972fc7b39b8e250ceca0f89329cfa95cfdba Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 21 Jan 2026 08:18:33 +0000 Subject: [PATCH 1/3] Rename ThresholdPress to DMSPress Signed-off-by: SimJeg --- README.md | 4 ++-- evaluation/evaluate.py | 8 ++++---- evaluation/evaluate_config.yaml | 2 +- evaluation/evaluate_registry.py | 6 +++--- kvpress/__init__.py | 4 ++-- kvpress/pipeline.py | 4 ++-- kvpress/presses/{threshold_press.py => dms_press.py} | 5 +++-- kvpress/presses/kvzap_press.py | 2 +- kvzap/README.md | 8 ++++---- kvzap/evaluate_aime.py | 8 ++++---- tests/presses/test_head_compression.py | 8 ++++---- tests/presses/test_presses.py | 8 ++++---- 12 files changed, 34 insertions(+), 33 deletions(-) rename kvpress/presses/{threshold_press.py => dms_press.py} (96%) diff --git a/README.md b/README.md index 8acbb121..c647b4ff 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres - `LeverageScorePress` ([source](kvpress/presses/leverage_press.py), [paper](https://arxiv.org/abs/2507.08143)): evicts tokens based on approximate statistical leverage (i.e we preserve outliers in the key space). - `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio. - `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and values based on the CUR decomposition using approximate leverage scores. -- `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `ThresholdPress`. +- `KVzapPress` ([source](kvpress/presses/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. 
To be used in conjunction with the `DMSPress`. Some presses rely on a different logic: - `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/abs/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries @@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses: - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively. - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding, see decoding section in this README. - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows to compress both during prefilling and during decoding. -- `ThresholdPress` ([source](kvpress/presses/threshold_press.py)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True). +- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Supports both prefilling and decoding (if `decoding=True`). 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index b0467492..2267432f 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -28,7 +28,7 @@ ObservedAttentionPress, ScorerPress, ThinKPress, - ThresholdPress, + DMSPress, ) logger = logging.getLogger(__name__) @@ -256,10 +256,10 @@ def _setup_press(self): if isinstance(press, DuoAttentionPress): press.head_compression_ratio = compression_ratio logger.info(f"Set DuoAttentionPress head_compression_ratio to {compression_ratio}") - elif isinstance(press, ThresholdPress): - assert self.config.threshold is not None, "threshold must be set for ThresholdPress" + elif isinstance(press, DMSPress): + assert self.config.threshold is not None, "threshold must be set for DMSPress" press.threshold = self.config.threshold - logger.info(f"Set ThresholdPress threshold to {press.threshold}") + logger.info(f"Set DMSPress threshold to {press.threshold}") elif isinstance(press, ComposedPress): for ps in press.presses: if isinstance(ps, ThinKPress): diff --git a/evaluation/evaluate_config.yaml b/evaluation/evaluate_config.yaml index 442c5135..1edf2ffa 100644 --- a/evaluation/evaluate_config.yaml +++ b/evaluation/evaluate_config.yaml @@ -10,7 +10,7 @@ data_dir: "4096" # Subdirectory of the dataset press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0) key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0) -threshold: null # For ThresholdPress +threshold: null # For DMSPress fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing max_new_tokens: null # Maximum new tokens to generate (null = use dataset 
default) diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index 307a234d..2274c43a 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -34,7 +34,7 @@ RandomPress, SnapKVPress, StreamingLLMPress, - ThresholdPress, + DMSPress, ThinKPress, TOVAPress, CURPress, @@ -86,8 +86,8 @@ "keydiff": KeyDiffPress(), "kvzip": KVzipPress(), "kvzip_plus": KVzipPress(kvzip_plus_normalization=True), - "kvzap_linear": ThresholdPress(press=KVzapPress(model_type="linear")), - "kvzap_mlp": ThresholdPress(press=KVzapPress(model_type="mlp")), + "kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")), + "kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")), "kvzap_mlp_head": KVzapPress(model_type="mlp"), "kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")), "lagkv": LagKVPress(), diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 03934bb4..7f124d4a 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -37,7 +37,7 @@ from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress -from kvpress.presses.threshold_press import ThresholdPress +from kvpress.presses.dms_press import DMSPress from kvpress.presses.tova_press import TOVAPress # Patch the attention functions to support head-wise compression @@ -80,5 +80,5 @@ "LeverageScorePress", "NonCausalAttnPress", "KVzapPress", - "ThresholdPress", + "DMSPress", ] diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 9eee07e0..0ad369c1 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -16,7 +16,7 @@ from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.prefill_decoding_press import PrefillDecodingPress -from kvpress.presses.threshold_press import ThresholdPress +from kvpress.presses.dms_press import DMSPress logger = 
logging.getLogger(__name__) @@ -224,7 +224,7 @@ def _forward( # We only perform decoding compression if the press is a decoding or prefill decoding press perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress)) - if isinstance(press, ThresholdPress): + if isinstance(press, DMSPress): perform_decoding_compression = press.decoding with press(self.model) if perform_decoding_compression else contextlib.nullcontext(): # Greedy decoding for each question diff --git a/kvpress/presses/threshold_press.py b/kvpress/presses/dms_press.py similarity index 96% rename from kvpress/presses/threshold_press.py rename to kvpress/presses/dms_press.py index 6618de74..de1636a2 100644 --- a/kvpress/presses/threshold_press.py +++ b/kvpress/presses/dms_press.py @@ -13,11 +13,12 @@ @dataclass -class ThresholdPress(BasePress): +class DMSPress(BasePress): """ + Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference. Wraps a ScorerPress and evicts keys/values with scores below a given threshold. - Unlike most presses that use a fixed compression_ratio, ThresholdPress uses a score threshold + Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold to determine which KV pairs to evict. This allows for adaptive compression where the actual compression ratio depends on the input content. diff --git a/kvpress/presses/kvzap_press.py b/kvpress/presses/kvzap_press.py index a663409b..e6ca5b82 100644 --- a/kvpress/presses/kvzap_press.py +++ b/kvpress/presses/kvzap_press.py @@ -50,7 +50,7 @@ class KVzapPress(ScorerPress): KVzap (https://arxiv.org/abs/2601.07891) is a fast approximation of KVzip that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores for every KV pair. 
- KVzapPress is designed to be used in conjunction with the ThresholdPress + KVzapPress is designed to be used in conjunction with the DMSPress model_type can be "linear" or "mlp". """ diff --git a/kvzap/README.md b/kvzap/README.md index b73d276b..19a93158 100644 --- a/kvzap/README.md +++ b/kvzap/README.md @@ -7,16 +7,16 @@ ## Usage -KVzap is designed to be used by combining the `KVzapPress` and the `ThresholdPress` from kvpress: +KVzap is designed to be used by combining the `KVzapPress` and the `DMSPress` from kvpress: ```python import requests from transformers import pipeline -from kvpress import KVzapPress, ThresholdPress +from kvpress import KVzapPress, DMSPress model = "Qwen/Qwen3-8B" pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", dtype="auto") -press = ThresholdPress(KVzapPress(model_type="mlp"), threshold=-4) +press = DMSPress(KVzapPress(model_type="mlp"), threshold=-4) # Prefilling compression only, thinking disabled press.decoding = False @@ -32,7 +32,7 @@ answer = pipe(prompt, press=press, enable_thinking=True, max_new_tokens=2000)["a print(f"Compression ratio: {press.compression_ratio:.2%}\nAnswer: {answer}") ``` -The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `ThresholdPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio. +The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `DMSPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio. Supported base models are provided in the [KVzap collection](https://huggingface.co/collections/nvidia/kvzap) but can easily be extended to any other model following the instructions in the [training section](#training). 
diff --git a/kvzap/evaluate_aime.py b/kvzap/evaluate_aime.py index 65c8a157..2c224440 100644 --- a/kvzap/evaluate_aime.py +++ b/kvzap/evaluate_aime.py @@ -10,7 +10,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset -from kvpress import KVzapPress, ThresholdPress +from kvpress import KVzapPress, DMSPress def calculate_metrics(df): @@ -56,11 +56,11 @@ def evaluate( """ # Create press - press: ThresholdPress | type[nullcontext[None]] + press: DMSPress | type[nullcontext[None]] if kvzap_model_type == "no_press": press = nullcontext else: - press = ThresholdPress( + press = DMSPress( KVzapPress(model_type=kvzap_model_type), threshold=threshold, decoding=True, @@ -86,7 +86,7 @@ def evaluate( ) answer = tokenizer.decode(output_tokens[0, tokens.shape[1] :]) df.loc[idx, "predicted_answer"] = answer - if isinstance(press, ThresholdPress): + if isinstance(press, DMSPress): df.loc[idx, "compression_ratio"] = press.compression_ratio else: df.loc[idx, "compression_ratio"] = 0 diff --git a/tests/presses/test_head_compression.py b/tests/presses/test_head_compression.py index aeb45385..f9138f86 100644 --- a/tests/presses/test_head_compression.py +++ b/tests/presses/test_head_compression.py @@ -4,7 +4,7 @@ import torch from transformers import DynamicCache -from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, ThresholdPress +from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, DMSPress from tests.fixtures import unit_test_model, kv_press_unit_test_pipeline # noqa: F401 @@ -59,9 +59,9 @@ def test_head_compression(unit_test_model, press, compression_ratio, layerwise): assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences -def test_threshold_press_compression_ratio(kv_press_unit_test_pipeline): # noqa: F811 - """Test that ThresholdPress.compression_ratio matches the actual masked percentage.""" - press = 
ThresholdPress( +def test_dms_press_compression_ratio(kv_press_unit_test_pipeline): # noqa: F811 + """Test that DMSPress.compression_ratio matches the actual masked percentage.""" + press = DMSPress( press=RandomPress(), threshold=0.5, sliding_window_size=0, diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 114ad03f..6b19fd94 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -20,7 +20,7 @@ ObservedAttentionPress, ScorerPress, SnapKVPress, - ThresholdPress, + DMSPress, ThinKPress, ) from tests.default_presses import default_presses @@ -69,7 +69,7 @@ def test_chunkkv_press(unit_test_model): # noqa: F811 ChunkPress, CriticalKVPress, CriticalAdaKVPress, - ThresholdPress, + DMSPress, ], ) def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 @@ -89,8 +89,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 press = wrapper_press(press=press) elif issubclass(wrapper_press, ChunkPress): press = ChunkPress(press=press, chunk_length=24) - elif issubclass(wrapper_press, ThresholdPress): - press = ThresholdPress(press=press, threshold=-0.5, sliding_window_size=32) + elif issubclass(wrapper_press, DMSPress): + press = DMSPress(press=press, threshold=-0.5, sliding_window_size=32) # TODO: Handle post_init_from_model differently if hasattr(press, "post_init_from_model"): From 46e0324346d172c2ae44745e985e5b164fdb46c1 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 21 Jan 2026 08:35:15 +0000 Subject: [PATCH 2/3] Update version Signed-off-by: SimJeg --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e8426701..94aae499 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kvpress" -version = "0.4.1" +version = "0.4.2" description = "Efficiently compress the KV cache of any pretrained transformer" authors = [ { name = "Simon Jegou" }, From 
a1c3bcdd1a6937d8af9022a91bcb3ca80a1668f5 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 21 Jan 2026 09:13:26 +0000 Subject: [PATCH 3/3] Update README Signed-off-by: SimJeg --- kvzap/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvzap/README.md b/kvzap/README.md index 19a93158..81eb03fe 100644 --- a/kvzap/README.md +++ b/kvzap/README.md @@ -3,7 +3,7 @@ [![KVzap collection](https://img.shields.io/badge/🤗%20Hugging%20Face-Collection-orange)](https://huggingface.co/collections/nvidia/kvzap) [![arXiv](https://img.shields.io/badge/arXiv-2601.07891-b31b1b.svg)](https://arxiv.org/abs/2601.07891) -[KVzap](https://arxiv.org/abs/2601.07891) is a fast approximation of [KVzip](https://arxiv.org/abs/2505.23416) that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores, and removes the KV pairs with a score below a given threshold. +[KVzap](https://arxiv.org/abs/2601.07891) is a fast approximation of [KVzip](https://arxiv.org/abs/2505.23416) that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores, and removes the KV pairs with a score below a given threshold, following the Dynamic Memory Sparsification ([DMS](https://arxiv.org/abs/2506.05345)) inference strategy. ## Usage