Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `PerLayerCompressionPress` ([source](kvpress/presses/per_layer_compression_press.py)): compress each layer with a different compression ratio (experimental)
- `ComposedPress` ([source](kvpress/presses/composed_press.py)): compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress` ([source](kvpress/presses/key_rerotation_press.py)): rerotate pruned keys to have continuous RoPE embeddings
- `ChunkKVPress` ([source](kvpress/presses/chunkkv_press.py), [paper](https://arxiv.org/abs/2502.00299)): compresses by selecting important chunks, preserving semantic coherence
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.

Expand Down Expand Up @@ -175,4 +176,4 @@ with press(model):

However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works regardless of what comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not allow providing generations for multiple questions at once.

</details>
</details>
8 changes: 5 additions & 3 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
CriticalKVPress,
CriticalAdaKVPress,
AdaKVPress,
ChunkKVPress,
CriticalAdaKVPress,
CriticalKVPress,
DuoAttentionPress,
ExpectedAttentionPress,
KnormPress,
ObservedAttentionPress,
Expand All @@ -28,7 +30,6 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
DuoAttentionPress,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,6 +65,7 @@
"think": ThinKPress(),
"tova": TOVAPress(),
"duo_attention": DuoAttentionPress(),
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
}


Expand Down
6 changes: 4 additions & 2 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.knorm_press import KnormPress
Expand All @@ -20,8 +23,6 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
Expand All @@ -47,4 +48,5 @@
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
]
1 change: 0 additions & 1 deletion kvpress/presses/chunk_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def compress(
assert attentions is None, "ChunkPress does not support attentions."

kv_len = keys.shape[2]

indices = []
for i in range(0, kv_len, self.chunk_length):
chunk_scores = self.press.score(
Expand Down
112 changes: 112 additions & 0 deletions kvpress/presses/chunkkv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ChunkKVPress(BasePress):
    """
    Wrapper press around any ScorerPress.

    First computes global per-token scores using the wrapped ScorerPress, then
    selects tokens chunk by chunk based on the mean score of each chunk.
    Keeping contiguous chunks (instead of isolated tokens) preserves semantic
    coherence of the compressed KV cache.

    This method was proposed in
    ChunkKV: Semantic-Preserving KV Cache Compression for Efficient Long-Context LLM Inference
    https://arxiv.org/abs/2502.00299
    """

    press: ScorerPress  # underlying press used to compute per-token scores
    chunk_length: int = 20  # number of consecutive tokens per chunk

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input"

    @property
    def compression_ratio(self):
        # Delegate to the wrapped press so this wrapper is transparent to callers
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Keep the top-scoring chunks of the KV cache.

        Returns the (keys, values) pair restricted to the selected chunks,
        sorted in original sequence order.
        """
        if self.press.compression_ratio == 0:
            return keys, values

        # Fix: message previously referenced "ChunkPress"
        assert attentions is None, "ChunkKVPress does not support attentions."

        kv_len = keys.shape[2]

        # 1. Global per-token scores from the wrapped press.
        #    Assumed shape (batch, num_heads, kv_len) — matches the sum(dim=1)
        #    head reduction below; confirm against ScorerPress.score.
        global_scores = self.press.score(
            module,
            hidden_states,
            keys,
            values,
            attentions,
            kwargs,
        )

        # 2. Split the sequence into complete chunks plus an optional partial tail
        num_complete_chunks = kv_len // self.chunk_length
        remaining_tokens = kv_len % self.chunk_length

        # With no complete chunk, chunking is meaningless:
        # fall back to plain token-level compression
        if num_complete_chunks == 0:
            return self.press.compress(module, hidden_states, keys, values, attentions, kwargs)

        # Mean score per complete chunk (scores summed over heads first).
        # The previous `else` branch creating an empty tensor was unreachable:
        # num_complete_chunks == 0 already returned above.
        main_scores = global_scores[..., : num_complete_chunks * self.chunk_length]
        main_chunk_scores = main_scores.sum(dim=1).view(-1, num_complete_chunks, self.chunk_length)
        main_chunk_scores = main_chunk_scores.mean(dim=-1)

        # The partial tail (if any) is scored as one extra chunk
        if remaining_tokens > 0:
            remaining_scores = global_scores[..., -remaining_tokens:]
            remaining_chunk_score = remaining_scores.sum(dim=1).mean(dim=-1, keepdim=True)
            chunk_scores = torch.cat([main_chunk_scores, remaining_chunk_score], dim=-1)
        else:
            chunk_scores = main_chunk_scores

        # 3. Number of chunks to keep (always at least one)
        n_chunks = num_complete_chunks + (1 if remaining_tokens > 0 else 0)
        n_chunks_kept = max(1, int(n_chunks * (1 - self.press.compression_ratio)))
        top_chunks = chunk_scores.topk(n_chunks_kept, dim=-1)

        # 4. Expand the selected chunk indices into token indices.
        # NOTE(review): only batch element 0 drives the selection, so batched
        # inputs share one selection — confirm batch size 1 is assumed upstream.
        indices = []
        for chunk_idx in top_chunks.indices[0]:
            if chunk_idx < num_complete_chunks:
                # Complete chunk: chunk_length consecutive tokens
                start_idx = chunk_idx * self.chunk_length
                chunk_indices = torch.arange(start_idx, start_idx + self.chunk_length, device=keys.device)
            else:
                # Partial tail chunk: covers the remainder of the sequence
                chunk_indices = torch.arange(num_complete_chunks * self.chunk_length, kv_len, device=keys.device)
            indices.append(chunk_indices)

        # Restore original sequence order, then broadcast over batch/heads/head_dim
        indices = torch.cat(indices).sort()[0]
        indices = indices.view(1, 1, -1, 1).expand(keys.shape[0], keys.shape[1], -1, module.head_dim)

        # 5. Gather the selected keys and values
        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values
4 changes: 2 additions & 2 deletions kvpress/presses/criticalkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.scorer_press import ScorerPress

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def vwl1norm(values, module):
# Future kernel fusion optimization could eliminate this intermediate variables to enhance performance.
head_WoV_norm_list = []
for head in range(V.size(1)):
head_WoV = V[: , head, : , ...].matmul(Wo[head, ...].unsqueeze(0))
head_WoV = V[:, head, :, ...].matmul(Wo[head, ...].unsqueeze(0))
head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
head_WoV_norm_list.append(head_WoV_norm)

Expand Down
9 changes: 4 additions & 5 deletions kvpress/presses/duo_attention_press.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from io import StringIO
from dataclasses import dataclass, field
from contextlib import contextmanager
from dataclasses import dataclass, field
from io import StringIO

import torch
import requests # type: ignore[import-untyped]
import numpy as np
import requests # type: ignore[import-untyped]
import torch

from kvpress.presses.base_press import BasePress


PATTERNS_DICT = {
"togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
"gradientai//Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
Expand Down
2 changes: 1 addition & 1 deletion tests/default_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from kvpress import (
DuoAttentionPress,
ExpectedAttentionPress,
KnormPress,
RandomPress,
Expand All @@ -12,7 +13,6 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
DuoAttentionPress,
)


Expand Down
2 changes: 1 addition & 1 deletion tests/presses/test_duo_attention_press.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kvpress.presses.duo_attention_press import DuoAttentionPress, PATTERNS_DICT
from kvpress.presses.duo_attention_press import PATTERNS_DICT, DuoAttentionPress


def test_load_attention_pattern():
Expand Down
23 changes: 19 additions & 4 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
from transformers import DynamicCache

from kvpress import (
CriticalKVPress,
CriticalAdaKVPress,
AdaKVPress,
ChunkKVPress,
ChunkPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
KeyRerotationPress,
KnormPress,
ObservedAttentionPress,
ScorerPress,
SnapKVPress,
ThinKPress,
)
from tests.default_presses import default_presses
Expand All @@ -43,9 +45,22 @@ def test_chunk_press(unit_test_model): # noqa: F811
assert cache.get_seq_length() == 128


def test_chunkkv_press(unit_test_model):  # noqa: F811
    """ChunkKVPress at compression_ratio=0.5 halves a 256-token cache for any chunk length."""
    press = SnapKVPress(compression_ratio=0.5)
    for chunk_length in [2, 4, 8, 128]:
        composed_press = ChunkKVPress(press=press, chunk_length=chunk_length)
        with composed_press(unit_test_model):
            input_ids = torch.randint(0, 1024, (1, 256))
            cache = DynamicCache()
            # Forward pass only for its side effect: the press hook compresses
            # `cache` in place (the discarded `.past_key_values` access is removed)
            unit_test_model(input_ids, past_key_values=cache)
            assert cache.get_seq_length() == 128


@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress,
CriticalKVPress, CriticalAdaKVPress])
@pytest.mark.parametrize(
"wrapper_press",
[None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress, CriticalKVPress, CriticalAdaKVPress],
)
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
Expand Down