Skip to content

Commit 3871dde

Browse files
authored
Add KeyDiffPress (#86)
1 parent 97408ee commit 3871dde

File tree

7 files changed

+188
-1
lines changed

7 files changed

+188
-1
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
7777
- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
7878
- `PyramidKVPress` ([source](kvpress/presses/pyramidkv_press.py), [paper](https://arxiv.org/abs/2406.02069)): maintain pyramid-like cache sizes, allocating more cache budget to lower layers and less to higher layers
7979
- `LagKVPress` ([source](kvpress/presses/lagkv_press.py), [paper](https://arxiv.org/abs/2504.04704)): leverages the KV lag-relative information to compress. It's query-free, attention-weight-free, and flash-attention compatible.
80+
- `KeyDiffPress` ([source](kvpress/presses/keydiff_press.py), [paper](https://arxiv.org/abs/2504.15364)): evicts tokens based solely on key similarity.
8081

8182
Some presses rely on a different logic:
8283
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
@@ -92,7 +93,7 @@ Finally we provide wrapper presses that can be combined with other presses:
9293
- `ChunkKVPress` ([source](kvpress/presses/chunkkv_press.py), [paper](https://arxiv.org/abs/2502.00299)): compresses by selecting important chunks, preserving semantic coherence
9394
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
9495
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
95-
96+
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
9697

9798
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
9899

evaluation/evaluate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
StreamingLLMPress,
3838
ThinKPress,
3939
TOVAPress,
40+
BlockPress,
41+
KeyDiffPress,
4042
)
4143

4244
logger = logging.getLogger(__name__)
@@ -84,6 +86,8 @@
8486
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
8587
"pyramidkv": PyramidKVPress(),
8688
"finch": FinchPress(),
89+
"keydiff": KeyDiffPress(),
90+
"block_keydiff": BlockPress(press=KeyDiffPress(), block_size=128),
8791
}
8892

8993

kvpress/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from kvpress.presses.finch_press import FinchPress
2929
from kvpress.presses.lagkv_press import LagKVPress
3030
from kvpress.presses.base_press import SUPPORTED_MODELS
31+
from kvpress.presses.block_press import BlockPress
32+
from kvpress.presses.keydiff_press import KeyDiffPress
3133

3234
# Patch the attention functions to support head-wise compression
3335
patch_attention_functions()
@@ -58,4 +60,6 @@
5860
"PyramidKVPress",
5961
"FinchPress",
6062
"LagKVPress",
63+
"BlockPress",
64+
"KeyDiffPress",
6165
]

kvpress/presses/block_press.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dataclasses import dataclass
5+
6+
import torch
7+
from torch import nn
8+
9+
from kvpress.presses.base_press import BasePress
10+
from kvpress.presses.scorer_press import ScorerPress
11+
12+
13+
@dataclass
class BlockPress(BasePress):
    """
    Simulates the block prompt processing described in KeyDiff (https://arxiv.org/abs/2504.15364).

    Segments the input sequence into non-overlapping blocks and compresses iteratively:
    at every step only the currently kept tokens plus one new block are scored, so the
    memory overhead stays bounded for long-context inference.

    Attributes
    ----------
    press : ScorerPress
        The scoring press used to rank KV pairs at each iteration.
    block_size : int, default 128
        Number of new tokens appended to the candidate set at each iteration.
    """

    press: ScorerPress
    block_size: int = 128

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "BlockPress requires a ScorerPress"

    @property
    def compression_ratio(self):
        # Delegate to the wrapped press so BlockPress is a drop-in replacement for it
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Iteratively compress the KV cache block by block.

        The first ``n_kept`` positions seed the kept set; each iteration appends the
        next block of positions, scores the union with the wrapped press, and retains
        the ``n_kept`` best-scoring positions.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Compressed (keys, values), each of shape
            (batch, num_key_value_heads, n_kept, head_dim).
        """
        if self.press.compression_ratio == 0:
            return keys, values

        assert attentions is None, "BlockPress does not support attentions."

        bsz, num_key_value_heads, q_len, head_dim = keys.shape

        # min() is the idiomatic clamp (equivalent to the conditional expression)
        block_size = min(self.block_size, q_len)
        n_kept = int(q_len * (1 - self.compression_ratio))

        # Seed the kept set with the first n_kept positions
        kept_indices = torch.arange(n_kept, device=keys.device).expand(bsz, num_key_value_heads, -1)

        # Reshape hidden states to match the kept_indices
        # NOTE(review): assumes hidden dim is divisible by num_key_value_heads — confirm for GQA models
        states = hidden_states.view(bsz, q_len, num_key_value_heads, -1).transpose(1, 2)

        for i in range(n_kept, q_len, block_size):
            end = min(i + block_size, q_len)
            current_indices = torch.arange(i, end, device=keys.device).expand(bsz, num_key_value_heads, -1)
            current_indices = torch.cat([kept_indices, current_indices], dim=-1)

            # Gather hidden states for the selected indices, then restore the shape
            # Check tests/presses/test_block_press.py for correctness verification of gathered hidden states
            current_states = states.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, states.shape[-1]))
            current_states = current_states.transpose(1, 2).reshape(bsz, -1, hidden_states.shape[-1])

            scores = self.press.score(
                module,
                current_states,
                keys.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)),
                values.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)),
                attentions,
                kwargs,
            )
            # Retain the n_kept best-scoring positions among kept + current block
            topk_indices = scores.topk(n_kept, dim=-1).indices
            kept_indices = current_indices.gather(-1, topk_indices)

        kept_indices = kept_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        keys = keys.gather(2, kept_indices).contiguous()
        values = values.gather(2, kept_indices).contiguous()

        return keys, values

kvpress/presses/keydiff_press.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dataclasses import dataclass
5+
6+
import torch
7+
from torch import nn
8+
from torch.nn import functional as F
9+
10+
from kvpress.presses.scorer_press import ScorerPress
11+
12+
13+
@dataclass
class KeyDiffPress(ScorerPress):
    """
    KeyDiff (https://arxiv.org/abs/2504.15364): evicts tokens based solely on key similarity.

    Each key is scored by its dissimilarity to an anchor direction — the mean of the
    L2-normalized keys along the sequence axis. Keys most similar to the anchor get the
    lowest scores and are therefore evicted first.
    """

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Anchor: average direction of the unit-norm keys over the sequence dimension
        unit_keys = F.normalize(keys, p=2, dim=-1)
        anchor = unit_keys.mean(dim=2, keepdim=True)
        # Negate so that higher score = less redundant = kept longer
        return -F.cosine_similarity(keys, anchor, dim=-1)

tests/default_presses.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
QFilterPress,
1717
PyramidKVPress,
1818
LagKVPress,
19+
KeyDiffPress,
1920
)
2021

2122

@@ -65,4 +66,5 @@ def load_attention_pattern(model):
6566
{"compression_ratio": 0.8, "n_sink": 16, "lag_size": 128}
6667
],
6768
},
69+
{"cls": KeyDiffPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
6870
]

tests/presses/test_block_press.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dataclasses import dataclass
5+
6+
import torch
7+
import torch.nn as nn
8+
from transformers import DynamicCache
9+
10+
from kvpress.presses.scorer_press import ScorerPress
11+
from kvpress.presses.block_press import BlockPress
12+
13+
from tests.fixtures import unit_test_model # noqa: F401
14+
15+
16+
@dataclass
class HiddenStatesPress(ScorerPress):
    """Dummy press scoring tokens from the hidden states, broadcast across KV heads."""

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # One score per token (mean over the hidden dimension), expanded to the
        # per-head shape that keys.norm(dim=-1) has: (bsz, num_heads, q_len)
        token_scores = hidden_states.mean(-1)
        return token_scores.unsqueeze(1).expand_as(keys.norm(dim=-1))
29+
30+
31+
def test_block_press_is_streaming_top_k(unit_test_model):  # noqa: F811
    """
    Test that BlockPress correctly applies the compression ratio and keeps the output
    consistent: for a score function independent of the other tokens in a block, the
    streaming (blocked) top-k selection must match the global top-k of the bare press,
    regardless of block size.
    """
    press = HiddenStatesPress(compression_ratio=0.5)
    generator = torch.Generator().manual_seed(0)
    input_ids = torch.randint(0, 1024, (1, 256), generator=generator)

    def run_press(p):
        # Prefill under the press; return (seq_len, key checksum, value checksum).
        # (Fixes the original's discarded `.past_key_values` attribute access.)
        with p(unit_test_model):
            cache = DynamicCache()
            unit_test_model(input_ids, past_key_values=cache)
        return (
            cache.get_seq_length(),
            torch.cat(cache.key_cache).sum().item(),
            torch.cat(cache.value_cache).sum().item(),
        )

    keys_hash = []
    values_hash = []

    for block_size in [2, 4, 8, 128, 256]:
        seq_len, key_sum, value_sum = run_press(BlockPress(press=press, block_size=block_size))
        assert seq_len == 128  # 256 tokens * (1 - 0.5)
        keys_hash.append(key_sum)
        values_hash.append(value_sum)

    # Reference: the wrapped press applied globally (no blocking)
    seq_len, key_sum, value_sum = run_press(press)
    assert seq_len == 128
    keys_hash.append(key_sum)
    values_hash.append(value_sum)

    keys_tensor = torch.tensor(keys_hash)
    values_tensor = torch.tensor(values_hash)
    assert torch.allclose(keys_tensor, keys_tensor[-1])
    assert torch.allclose(values_tensor, values_tensor[-1])

0 commit comments

Comments
 (0)